├── README.md
├── config
│   └── config.py
├── data_processing.py
├── dataset
│   ├── dataset_DINet_clip.py
│   └── dataset_DINet_frame.py
├── inference.py
├── models
│   ├── DINet.py
│   ├── Discriminator.py
│   ├── Syncnet.py
│   ├── VGG19.py
│   └── old
│       ├── Syncnet_BN.py
│       └── Syncnet_halfBN.py
├── requirements.txt
├── sync_batchnorm
│   ├── __init__.py
│   ├── batchnorm.py
│   ├── batchnorm_reimpl.py
│   ├── comm.py
│   ├── replicate.py
│   └── unittest.py
├── train_DINet_clip.py
├── train_DINet_frame.py
└── utils
    ├── data_processing.py
    ├── deep_speech.py
    └── training_utils.py

/README.md:
--------------------------------------------------------------------------------
1 | # DINet: Deformation Inpainting Network for Realistic Face Visually Dubbing on High Resolution Video (AAAI2023)
2 | ![DINet overview](https://img-blog.csdnimg.cn/178c6b3ec0074af7a2dcc9ef26450e75.png)
3 | [Paper](https://fuxivirtualhuman.github.io/pdf/AAAI2023_FaceDubbing.pdf)    [demo video](https://www.youtube.com/watch?v=UU344T-9h7M&t=6s)    Supplementary materials
4 | 
5 | ## Inference
6 | ##### Download the resources (asserts.zip) from [Google Drive](https://drive.google.com/drive/folders/1rPtOo9Uuhc59YfFVv4gBmkh0_oG0nCQb?usp=share_link), unzip them, and put the directory in ./.
7 | + Inference with the example videos. Run
8 | ```python
9 | python inference.py --mouth_region_size=256 --source_video_path=./asserts/examples/testxxx.mp4 --source_openface_landmark_path=./asserts/examples/testxxx.csv --driving_audio_path=./asserts/examples/driving_audio_xxx.wav --pretrained_clip_DINet_path=./asserts/clip_training_DINet_256mouth.pth
10 | ```
11 | The results are saved in ./asserts/inference_result.
12 | 
13 | + Inference with custom videos.
14 | **Note:** The released pretrained model is trained on the HDTF dataset with 363 training videos (the video names are listed in ./asserts/training_video_name.txt), so its generalization is limited. It is best to test custom videos with normal lighting, a frontal view, etc. (see the limitation section in the paper). **We also release the training code**, so if a larger high-resolution audio-visual dataset becomes available in the future, you can use the training code to train a model with better generalization. Besides, we release the coarse-to-fine training strategy, **so you can use the training code to train a model at an arbitrary resolution** (larger than 416x320, if GPU memory and training data allow).
15 | 
16 | Use [OpenFace](https://github.com/TadasBaltrusaitis/OpenFace) to detect smoothed facial landmarks in your custom video. We run **OpenFaceOffline.exe** on a Windows 10 system with the following settings:
17 | 
18 | | Record | Recording settings | OpenFace setting | View | Face Detector | Landmark Detector |
19 | |--|--|--|--|--|--|
20 | | 2D landmark & tracked videos | Mask aligned image | Use dynamic AU models | Show video | OpenFace (MTCNN) | CE-CLM |
21 | 
22 | The detected facial landmarks are saved in "xxxx.csv". Run
23 | ```python
24 | python inference.py --mouth_region_size=256 --source_video_path= custom video path --source_openface_landmark_path= detected landmark path --driving_audio_path= driving audio path --pretrained_clip_DINet_path=./asserts/clip_training_DINet_256mouth.pth
25 | ```
26 | to perform face visual dubbing on your custom videos.
27 | ## Training
28 | ### Data Processing
29 | We release the video-processing code for the [HDTF dataset](https://github.com/MRzzm/HDTF). You can also use this code to process custom videos.
30 | 
31 | 1. Download videos from the [HDTF dataset](https://github.com/MRzzm/HDTF). Split the videos according to xx_annotion_time.txt and **do not** crop or resize them (a scriptable ffmpeg sketch for this step is given below).
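If you prefer a scriptable route for this splitting step (and for the 25 fps resampling in step 2), something like the hypothetical helper below cuts one clip out of a downloaded video at an annotated start/end time and forces the output to 25 fps with ffmpeg, without cropping or resizing. It mirrors the way data_processing.py shells out to ffmpeg, but the function name, paths, and timestamps are placeholders, not part of the released code.
```python
import subprocess

def split_and_resample(src_video, start_time, end_time, dst_video):
    # Cut [start_time, end_time] from the source video and re-encode it at 25 fps.
    # No cropping or resizing is applied; the audio stream is copied as-is.
    cmd = 'ffmpeg -i {} -ss {} -to {} -filter:v fps=25 -c:a copy {}'.format(
        src_video, start_time, end_time, dst_video)
    subprocess.call(cmd, shell=True)

# Example with placeholder paths and timestamps taken from an annotation file:
split_and_resample('./downloads/xxxx.mp4', '00:00:05.0', '00:01:10.0',
                   './asserts/split_video_25fps/xxxx_clip000.mp4')
```
Dropping the -ss/-to arguments turns the same command into a plain 25 fps conversion, which should give the same result as the GUI tool mentioned in step 2.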
32 | 2. Resample all split videos to **25 fps** and put them into "./asserts/split_video_25fps". Two example videos are provided in "./asserts/split_video_25fps". We use this [software](http://www.pcfreetime.com/formatfactory/cn/index.html) to resample videos. We provide the name list of the training videos used in our experiment (see "./asserts/training_video_name.txt").
33 | 3. Use [OpenFace](https://github.com/TadasBaltrusaitis/OpenFace) to detect smoothed facial landmarks in all videos, and put all ".csv" results into "./asserts/split_video_25fps_landmark_openface". Two example csv files are provided in "./asserts/split_video_25fps_landmark_openface".
34 | 
35 | 4. Extract frames from all videos and save them in "./asserts/split_video_25fps_frame". Run
36 | ```python
37 | python data_processing.py --extract_video_frame
38 | ```
39 | 5. Extract audio from all videos and save it in "./asserts/split_video_25fps_audio". Run
40 | ```python
41 | python data_processing.py --extract_audio
42 | ```
43 | 6. Extract DeepSpeech features from all audio files and save them in "./asserts/split_video_25fps_deepspeech". Run
44 | ```python
45 | python data_processing.py --extract_deep_speech
46 | ```
47 | 7. Crop faces from all videos and save the images in "./asserts/split_video_25fps_crop_face". Run
48 | ```python
49 | python data_processing.py --crop_face
50 | ```
51 | 8. Generate the training json file "./asserts/training_json.json". Run
52 | ```python
53 | python data_processing.py --generate_training_json
54 | ```
55 | 
56 | ### Training models
57 | We split the training process into a **frame training stage** and a **clip training stage**. In the frame training stage, we use a coarse-to-fine strategy, **so you can train the model at an arbitrary resolution**.
58 | 
59 | #### Frame training stage
60 | In the frame training stage, we only use the perception loss and the GAN loss.
61 | 
62 | 1. First, train DINet at 104x80 resolution (the mouth region is 64x64). Run
63 | ```python
64 | python train_DINet_frame.py --augment_num=32 --mouth_region_size=64 --batch_size=24 --result_path=./asserts/training_model_weight/frame_training_64
65 | ```
66 | You can stop the training when the loss converges (we stopped at about epoch 270).
67 | 
68 | 2. Load the pretrained model (face: 104x80 & mouth: 64x64) and train DINet at a higher resolution (face: 208x160 & mouth: 128x128). Run
69 | ```python
70 | python train_DINet_frame.py --augment_num=100 --mouth_region_size=128 --batch_size=80 --coarse2fine --coarse_model_path=./asserts/training_model_weight/frame_training_64/xxxxxx.pth --result_path=./asserts/training_model_weight/frame_training_128
71 | ```
72 | You can stop the training when the loss converges (we stopped at about epoch 200).
73 | 
74 | 3. Load the pretrained model (face: 208x160 & mouth: 128x128) and train DINet at a higher resolution (face: 416x320 & mouth: 256x256). Run
75 | ```python
76 | python train_DINet_frame.py --augment_num=20 --mouth_region_size=256 --batch_size=12 --coarse2fine --coarse_model_path=./asserts/training_model_weight/frame_training_128/xxxxxx.pth --result_path=./asserts/training_model_weight/frame_training_256
77 | ```
78 | You can stop the training when the loss converges (we stopped at about epoch 200).
79 | 
80 | #### Clip training stage
81 | In the clip training stage, we use the perception loss, the frame/clip GAN loss, and the sync loss (an illustrative sketch of how these three terms might be combined is shown below).
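The snippet below is a simplified, hypothetical sketch of how these three clip-stage terms could be assembled into a generator objective. The weights follow the defaults in config/config.py (`--lamb_perception=10`, `--lamb_syncnet_perception=0.1`), but the loss modules, the GAN formulation, and the exact weighting used in train_DINet_clip.py may differ; treat it as orientation, not as the released implementation.
```python
import torch
import torch.nn as nn

def clip_stage_generator_loss(vgg_feats_fake, vgg_feats_real,
                              d_frame_fake_score, d_clip_fake_score,
                              syncnet_score,
                              lamb_perception=10.0, lamb_sync=0.1):
    """Simplified sketch of the clip-stage generator objective (not the exact code)."""
    l1, mse = nn.L1Loss(), nn.MSELoss()
    # perception loss: distance between VGG19 feature maps of generated and real frames
    perception = sum(l1(f, r) for f, r in zip(vgg_feats_fake, vgg_feats_real))
    # least-squares GAN terms for the frame and clip discriminators
    gan_frame = mse(d_frame_fake_score, torch.ones_like(d_frame_fake_score))
    gan_clip = mse(d_clip_fake_score, torch.ones_like(d_clip_fake_score))
    # sync loss: push the SyncNet score of (generated mouth clip, driving audio) towards 1
    sync = mse(syncnet_score, torch.ones_like(syncnet_score))
    return gan_frame + gan_clip + lamb_perception * perception + lamb_sync * sync
```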
Load the pretrained frame model (face: 416x320 & mouth: 256x256) and the pretrained SyncNet model (mouth: 256x256), and train DINet in the clip setting. Run
82 | ```python
83 | python train_DINet_clip.py --augment_num=3 --mouth_region_size=256 --batch_size=3 --pretrained_syncnet_path=./asserts/syncnet_256mouth.pth --pretrained_frame_DINet_path=./asserts/training_model_weight/frame_training_256/xxxxx.pth --result_path=./asserts/training_model_weight/clip_training_256
84 | ```
85 | You can stop the training when the loss converges and select the best model (our best model is from about epoch 160).
86 | 
87 | ## Acknowledgements
88 | The AdaAT operator is borrowed from [AdaAT](https://github.com/MRzzm/AdaAT). The DeepSpeech feature extraction is borrowed from [AD-NeRF](https://github.com/YudongGuo/AD-NeRF). The basic modules are borrowed from [first-order](https://github.com/AliaksandrSiarohin/first-order-model). Thanks for their released code.
--------------------------------------------------------------------------------
/config/config.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | 
3 | class DataProcessingOptions():
4 |     def __init__(self):
5 |         self.parser = argparse.ArgumentParser()
6 | 
7 |     def parse_args(self):
8 |         self.parser.add_argument('--extract_video_frame', action='store_true', help='extract video frame')
9 |         self.parser.add_argument('--extract_audio', action='store_true', help='extract audio files from videos')
10 |         self.parser.add_argument('--extract_deep_speech', action='store_true', help='extract deep speech features')
11 |         self.parser.add_argument('--crop_face', action='store_true', help='crop face')
12 |         self.parser.add_argument('--generate_training_json', action='store_true', help='generate training json file')
13 | 
14 |         self.parser.add_argument('--source_video_dir', type=str, default="./asserts/training_data/split_video_25fps",
15 |                                  help='path of source video in 25 fps')
16 |         self.parser.add_argument('--openface_landmark_dir', type=str, default="./asserts/training_data/split_video_25fps_landmark_openface",
17 |                                  help='path of openface landmark dir')
18 |         self.parser.add_argument('--video_frame_dir', type=str, default="./asserts/training_data/split_video_25fps_frame",
19 |                                  help='path of video frames')
20 |         self.parser.add_argument('--audio_dir', type=str, default="./asserts/training_data/split_video_25fps_audio",
21 |                                  help='path of audios')
22 |         self.parser.add_argument('--deep_speech_dir', type=str, default="./asserts/training_data/split_video_25fps_deepspeech",
23 |                                  help='path of deep speech')
24 |         self.parser.add_argument('--crop_face_dir', type=str, default="./asserts/training_data/split_video_25fps_crop_face",
25 |                                  help='path of crop face dir')
26 |         self.parser.add_argument('--json_path', type=str, default="./asserts/training_data/training_json.json",
27 |                                  help='path of training json')
28 |         self.parser.add_argument('--clip_length', type=int, default=9, help='clip length')
29 |         self.parser.add_argument('--deep_speech_model', type=str, default="./asserts/output_graph.pb",
30 |                                  help='path of pretrained deepspeech model')
31 |         return self.parser.parse_args()
32 | 
33 | class DINetTrainingOptions():
34 |     def __init__(self):
35 |         self.parser = argparse.ArgumentParser()
36 | 
37 |     def parse_args(self):
38 |         self.parser.add_argument('--seed', type=int, default=456, help='random seed to use.')
39 |         self.parser.add_argument('--source_channel', type=int, default=3, help='input source image channels')
40 |         self.parser.add_argument('--ref_channel', type=int,
default=15, help='input reference image channels') 41 | self.parser.add_argument('--audio_channel', type=int, default=29, help='input audio channels') 42 | self.parser.add_argument('--augment_num', type=int, default=32, help='augment training data') 43 | self.parser.add_argument('--mouth_region_size', type=int, default=64, help='augment training data') 44 | self.parser.add_argument('--train_data', type=str, default=r"./asserts/training_data/training_json.json", 45 | help='path of training json') 46 | self.parser.add_argument('--batch_size', type=int, default=24, help='training batch size') 47 | self.parser.add_argument('--lamb_perception', type=int, default=10, help='weight of perception loss') 48 | self.parser.add_argument('--lamb_syncnet_perception', type=int, default=0.1, help='weight of perception loss') 49 | self.parser.add_argument('--lr_g', type=float, default=0.0001, help='initial learning rate for adam') 50 | self.parser.add_argument('--lr_dI', type=float, default=0.0001, help='initial learning rate for adam') 51 | self.parser.add_argument('--start_epoch', default=1, type=int, help='start epoch in training stage') 52 | self.parser.add_argument('--non_decay', default=200, type=int, help='num of epoches with fixed learning rate') 53 | self.parser.add_argument('--decay', default=200, type=int, help='num of linearly decay epochs') 54 | self.parser.add_argument('--checkpoint', type=int, default=2, help='num of checkpoints in training stage') 55 | self.parser.add_argument('--result_path', type=str, default=r"./asserts/training_model_weight/frame_training_64", 56 | help='result path to save model') 57 | self.parser.add_argument('--coarse2fine', action='store_true', help='If true, load pretrained model path.') 58 | self.parser.add_argument('--coarse_model_path', 59 | default='', 60 | type=str, 61 | help='Save data (.pth) of previous training') 62 | self.parser.add_argument('--pretrained_syncnet_path', 63 | default='', 64 | type=str, 65 | help='Save data (.pth) of pretrained syncnet') 66 | self.parser.add_argument('--pretrained_frame_DINet_path', 67 | default='', 68 | type=str, 69 | help='Save data (.pth) of frame trained DINet') 70 | # ========================= Discriminator ========================== 71 | self.parser.add_argument('--D_num_blocks', type=int, default=4, help='num of down blocks in discriminator') 72 | self.parser.add_argument('--D_block_expansion', type=int, default=64, help='block expansion in discriminator') 73 | self.parser.add_argument('--D_max_features', type=int, default=256, help='max channels in discriminator') 74 | return self.parser.parse_args() 75 | 76 | 77 | class DINetInferenceOptions(): 78 | def __init__(self): 79 | self.parser = argparse.ArgumentParser() 80 | 81 | def parse_args(self): 82 | self.parser.add_argument('--source_channel', type=int, default=3, help='channels of source image') 83 | self.parser.add_argument('--ref_channel', type=int, default=15, help='channels of reference image') 84 | self.parser.add_argument('--audio_channel', type=int, default=29, help='channels of audio feature') 85 | self.parser.add_argument('--mouth_region_size', type=int, default=256, help='help to resize window') 86 | self.parser.add_argument('--source_video_path', 87 | default='./asserts/examples/test4.mp4', 88 | type=str, 89 | help='path of source video') 90 | self.parser.add_argument('--source_openface_landmark_path', 91 | default='./asserts/examples/test4.csv', 92 | type=str, 93 | help='path of detected openface landmark') 94 | 
self.parser.add_argument('--driving_audio_path', 95 | default='./asserts/examples/driving_audio_1.wav', 96 | type=str, 97 | help='path of driving audio') 98 | self.parser.add_argument('--pretrained_clip_DINet_path', 99 | default='./asserts/clip_training_DINet_256mouth.pth', 100 | type=str, 101 | help='pretrained model of DINet(clip trained)') 102 | self.parser.add_argument('--deepspeech_model_path', 103 | default='./asserts/output_graph.pb', 104 | type=str, 105 | help='path of deepspeech model') 106 | self.parser.add_argument('--res_video_dir', 107 | default='./asserts/inference_result', 108 | type=str, 109 | help='path of generated videos') 110 | return self.parser.parse_args() -------------------------------------------------------------------------------- /data_processing.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import subprocess 4 | import cv2 5 | import numpy as np 6 | import json 7 | 8 | from utils.data_processing import load_landmark_openface,compute_crop_radius 9 | from utils.deep_speech import DeepSpeech 10 | from config.config import DataProcessingOptions 11 | 12 | def extract_audio(source_video_dir,res_audio_dir): 13 | ''' 14 | extract audio files from videos 15 | ''' 16 | if not os.path.exists(source_video_dir): 17 | raise ('wrong path of video dir') 18 | if not os.path.exists(res_audio_dir): 19 | os.mkdir(res_audio_dir) 20 | video_path_list = glob.glob(os.path.join(source_video_dir, '*.mp4')) 21 | for video_path in video_path_list: 22 | print('extract audio from video: {}'.format(os.path.basename(video_path))) 23 | audio_path = os.path.join(res_audio_dir, os.path.basename(video_path).replace('.mp4', '.wav')) 24 | cmd = 'ffmpeg -i {} -f wav -ar 16000 {}'.format(video_path, audio_path) 25 | subprocess.call(cmd, shell=True) 26 | 27 | def extract_deep_speech(audio_dir,res_deep_speech_dir,deep_speech_model_path): 28 | ''' 29 | extract deep speech feature 30 | ''' 31 | if not os.path.exists(res_deep_speech_dir): 32 | os.mkdir(res_deep_speech_dir) 33 | DSModel = DeepSpeech(deep_speech_model_path) 34 | wav_path_list = glob.glob(os.path.join(audio_dir, '*.wav')) 35 | for wav_path in wav_path_list: 36 | video_name = os.path.basename(wav_path).replace('.wav', '') 37 | res_dp_path = os.path.join(res_deep_speech_dir, video_name + '_deepspeech.txt') 38 | if os.path.exists(res_dp_path): 39 | os.remove(res_dp_path) 40 | print('extract deep speech feature from audio:{}'.format(video_name)) 41 | ds_feature = DSModel.compute_audio_feature(wav_path) 42 | np.savetxt(res_dp_path, ds_feature) 43 | 44 | def extract_video_frame(source_video_dir,res_video_frame_dir): 45 | ''' 46 | extract video frames from videos 47 | ''' 48 | if not os.path.exists(source_video_dir): 49 | raise ('wrong path of video dir') 50 | if not os.path.exists(res_video_frame_dir): 51 | os.mkdir(res_video_frame_dir) 52 | video_path_list = glob.glob(os.path.join(source_video_dir, '*.mp4')) 53 | for video_path in video_path_list: 54 | video_name = os.path.basename(video_path) 55 | frame_dir = os.path.join(res_video_frame_dir, video_name.replace('.mp4', '')) 56 | if not os.path.exists(frame_dir): 57 | os.makedirs(frame_dir) 58 | print('extracting frames from {} ...'.format(video_name)) 59 | videoCapture = cv2.VideoCapture(video_path) 60 | fps = videoCapture.get(cv2.CAP_PROP_FPS) 61 | if int(fps) != 25: 62 | raise ('{} video is not in 25 fps'.format(video_path)) 63 | frames = videoCapture.get(cv2.CAP_PROP_FRAME_COUNT) 64 | for i in range(int(frames)): 65 | 
ret, frame = videoCapture.read() 66 | result_path = os.path.join(frame_dir, str(i).zfill(6) + '.jpg') 67 | cv2.imwrite(result_path, frame) 68 | 69 | 70 | def crop_face_according_openfaceLM(openface_landmark_dir,video_frame_dir,res_crop_face_dir,clip_length): 71 | ''' 72 | crop face according to openface landmark 73 | ''' 74 | if not os.path.exists(openface_landmark_dir): 75 | raise ('wrong path of openface landmark dir') 76 | if not os.path.exists(video_frame_dir): 77 | raise ('wrong path of video frame dir') 78 | if not os.path.exists(res_crop_face_dir): 79 | os.mkdir(res_crop_face_dir) 80 | landmark_openface_path_list = glob.glob(os.path.join(openface_landmark_dir, '*.csv')) 81 | for landmark_openface_path in landmark_openface_path_list: 82 | video_name = os.path.basename(landmark_openface_path).replace('.csv', '') 83 | crop_face_video_dir = os.path.join(res_crop_face_dir, video_name) 84 | if not os.path.exists(crop_face_video_dir): 85 | os.makedirs(crop_face_video_dir) 86 | print('cropping face from video: {} ...'.format(video_name)) 87 | landmark_openface_data = load_landmark_openface(landmark_openface_path).astype(np.int) 88 | frame_dir = os.path.join(video_frame_dir, video_name) 89 | if not os.path.exists(frame_dir): 90 | raise ('run last step to extract video frame') 91 | if len(glob.glob(os.path.join(frame_dir, '*.jpg'))) != landmark_openface_data.shape[0]: 92 | raise ('landmark length is different from frame length') 93 | frame_length = min(len(glob.glob(os.path.join(frame_dir, '*.jpg'))), landmark_openface_data.shape[0]) 94 | end_frame_index = list(range(clip_length, frame_length, clip_length)) 95 | video_clip_num = len(end_frame_index) 96 | for i in range(video_clip_num): 97 | first_image = cv2.imread(os.path.join(frame_dir, '000000.jpg')) 98 | video_h,video_w = first_image.shape[0], first_image.shape[1] 99 | crop_flag, radius_clip = compute_crop_radius((video_w,video_h), 100 | landmark_openface_data[end_frame_index[i] - clip_length:end_frame_index[i], :,:]) 101 | if not crop_flag: 102 | continue 103 | radius_clip_1_4 = radius_clip // 4 104 | print('cropping {}/{} clip from video:{}'.format(i, video_clip_num, video_name)) 105 | res_face_clip_dir = os.path.join(crop_face_video_dir, str(i).zfill(6)) 106 | if not os.path.exists(res_face_clip_dir): 107 | os.mkdir(res_face_clip_dir) 108 | for frame_index in range(end_frame_index[i]- clip_length,end_frame_index[i]): 109 | source_frame_path = os.path.join(frame_dir,str(frame_index).zfill(6)+'.jpg') 110 | source_frame_data = cv2.imread(source_frame_path) 111 | frame_landmark = landmark_openface_data[frame_index, :, :] 112 | crop_face_data = source_frame_data[ 113 | frame_landmark[29, 1] - radius_clip:frame_landmark[ 114 | 29, 1] + radius_clip * 2 + radius_clip_1_4, 115 | frame_landmark[33, 0] - radius_clip - radius_clip_1_4:frame_landmark[ 116 | 33, 0] + radius_clip + radius_clip_1_4, 117 | :].copy() 118 | res_crop_face_frame_path = os.path.join(res_face_clip_dir, str(frame_index).zfill(6) + '.jpg') 119 | if os.path.exists(res_crop_face_frame_path): 120 | os.remove(res_crop_face_frame_path) 121 | cv2.imwrite(res_crop_face_frame_path, crop_face_data) 122 | 123 | 124 | def generate_training_json(crop_face_dir,deep_speech_dir,clip_length,res_json_path): 125 | video_name_list = os.listdir(crop_face_dir) 126 | video_name_list.sort() 127 | res_data_dic = {} 128 | for video_index, video_name in enumerate(video_name_list): 129 | print('generate training json file :{} {}/{}'.format(video_name,video_index,len(video_name_list))) 130 | tem_dic = {} 
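        # Descriptive note on the structure built below (keys taken from this function):
        # the dictionary dumped to res_json_path has one entry per video, nested as
        #   { video_name: { 'video_clip_num': int,
        #                   'clip_data_list': [ { 'frame_name_list': [...],
        #                                         'frame_path_list': [...],
        #                                         'deep_speech_list': [...] }, ... ] } }
        # Each clip holds clip_length frames (9 by default) together with the matching
        # per-frame deep speech feature rows; dataset_DINet_frame.py and
        # dataset_DINet_clip.py index into these keys.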
131 | deep_speech_feature_path = os.path.join(deep_speech_dir, video_name + '_deepspeech.txt') 132 | if not os.path.exists(deep_speech_feature_path): 133 | raise ('wrong path of deep speech') 134 | deep_speech_feature = np.loadtxt(deep_speech_feature_path) 135 | video_clip_dir = os.path.join(crop_face_dir, video_name) 136 | clip_name_list = os.listdir(video_clip_dir) 137 | clip_name_list.sort() 138 | video_clip_num = len(clip_name_list) 139 | clip_data_list = [] 140 | for clip_index, clip_name in enumerate(clip_name_list): 141 | tem_tem_dic = {} 142 | clip_frame_dir = os.path.join(video_clip_dir, clip_name) 143 | frame_path_list = glob.glob(os.path.join(clip_frame_dir, '*.jpg')) 144 | frame_path_list.sort() 145 | assert len(frame_path_list) == clip_length 146 | start_index = int(float(clip_name) * clip_length) 147 | assert int(float(os.path.basename(frame_path_list[0]).replace('.jpg', ''))) == start_index 148 | frame_name_list = [video_name + '/' + clip_name + '/' + os.path.basename(item) for item in frame_path_list] 149 | deep_speech_list = deep_speech_feature[start_index:start_index + clip_length, :].tolist() 150 | if len(frame_name_list) != len(deep_speech_list): 151 | print(' skip video: {}:{}/{} clip:{}:{}/{} because of different length: {} {}'.format( 152 | video_name,video_index,len(video_name_list),clip_name,clip_index,len(clip_name_list), 153 | len(frame_name_list),len(deep_speech_list))) 154 | tem_tem_dic['frame_name_list'] = frame_name_list 155 | tem_tem_dic['frame_path_list'] = frame_path_list 156 | tem_tem_dic['deep_speech_list'] = deep_speech_list 157 | clip_data_list.append(tem_tem_dic) 158 | tem_dic['video_clip_num'] = video_clip_num 159 | tem_dic['clip_data_list'] = clip_data_list 160 | res_data_dic[video_name] = tem_dic 161 | if os.path.exists(res_json_path): 162 | os.remove(res_json_path) 163 | with open(res_json_path,'w') as f: 164 | json.dump(res_data_dic,f) 165 | 166 | 167 | if __name__ == '__main__': 168 | opt = DataProcessingOptions().parse_args() 169 | ########## step1: extract video frames 170 | if opt.extract_video_frame: 171 | extract_video_frame(opt.source_video_dir, opt.video_frame_dir) 172 | ########## step2: extract audio files 173 | if opt.extract_audio: 174 | extract_audio(opt.source_video_dir,opt.audio_dir) 175 | ########## step3: extract deep speech features 176 | if opt.extract_deep_speech: 177 | extract_deep_speech(opt.audio_dir, opt.deep_speech_dir,opt.deep_speech_model) 178 | ########## step4: crop face images 179 | if opt.crop_face: 180 | crop_face_according_openfaceLM(opt.openface_landmark_dir,opt.video_frame_dir,opt.crop_face_dir,opt.clip_length) 181 | ########## step5: generate training json file 182 | if opt.generate_training_json: 183 | generate_training_json(opt.crop_face_dir,opt.deep_speech_dir,opt.clip_length,opt.json_path) 184 | 185 | 186 | -------------------------------------------------------------------------------- /dataset/dataset_DINet_clip.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import json 4 | import random 5 | import cv2 6 | 7 | from torch.utils.data import Dataset 8 | 9 | 10 | def get_data(json_name,augment_num): 11 | print('start loading data') 12 | with open(json_name,'r') as f: 13 | data_dic = json.load(f) 14 | data_dic_name_list = [] 15 | for augment_index in range(augment_num): 16 | for video_name in data_dic.keys(): 17 | data_dic_name_list.append(video_name) 18 | random.shuffle(data_dic_name_list) 19 | print('finish loading') 20 | 
return data_dic_name_list,data_dic 21 | 22 | 23 | class DINetDataset(Dataset): 24 | def __init__(self,path_json,augment_num,mouth_region_size): 25 | super(DINetDataset, self).__init__() 26 | self.data_dic_name_list,self.data_dic = get_data(path_json,augment_num) 27 | self.mouth_region_size = mouth_region_size 28 | self.radius = mouth_region_size//2 29 | self.radius_1_4 = self.radius//4 30 | self.img_h = self.radius * 3 + self.radius_1_4 31 | self.img_w = self.radius * 2 + self.radius_1_4 * 2 32 | self.length = len(self.data_dic_name_list) 33 | 34 | def __getitem__(self, index): 35 | video_name = self.data_dic_name_list[index] 36 | video_clip_num = len(self.data_dic[video_name]['clip_data_list']) 37 | source_anchor = random.sample(range(video_clip_num), 1)[0] 38 | source_image_path_list = self.data_dic[video_name]['clip_data_list'][source_anchor]['frame_path_list'] 39 | source_clip_list = [] 40 | source_clip_mask_list = [] 41 | deep_speech_list = [] 42 | reference_clip_list = [] 43 | for source_frame_index in range(2, 2 + 5): 44 | ## load source clip 45 | source_image_data = cv2.imread(source_image_path_list[source_frame_index])[:, :, ::-1] 46 | source_image_data = cv2.resize(source_image_data, (self.img_w, self.img_h)) / 255.0 47 | source_clip_list.append(source_image_data) 48 | source_image_mask = source_image_data.copy() 49 | source_image_mask[self.radius:self.radius + self.mouth_region_size, 50 | self.radius_1_4:self.radius_1_4 + self.mouth_region_size, :] = 0 51 | source_clip_mask_list.append(source_image_mask) 52 | 53 | ## load deep speech feature 54 | deepspeech_array = np.array(self.data_dic[video_name]['clip_data_list'][source_anchor]['deep_speech_list'][ 55 | source_frame_index - 2:source_frame_index + 3]) 56 | deep_speech_list.append(deepspeech_array) 57 | 58 | ## ## load reference images 59 | reference_frame_list = [] 60 | reference_anchor_list = random.sample(range(video_clip_num), 5) 61 | for reference_anchor in reference_anchor_list: 62 | reference_frame_path_list = self.data_dic[video_name]['clip_data_list'][reference_anchor][ 63 | 'frame_path_list'] 64 | reference_random_index = random.sample(range(9), 1)[0] 65 | reference_frame_path = reference_frame_path_list[reference_random_index] 66 | reference_frame_data = cv2.imread(reference_frame_path)[:, :, ::-1] 67 | reference_frame_data = cv2.resize(reference_frame_data, (self.img_w, self.img_h)) / 255.0 68 | reference_frame_list.append(reference_frame_data) 69 | reference_clip_list.append(np.concatenate(reference_frame_list, 2)) 70 | 71 | source_clip = np.stack(source_clip_list, 0) 72 | source_clip_mask = np.stack(source_clip_mask_list, 0) 73 | deep_speech_clip = np.stack(deep_speech_list, 0) 74 | reference_clip = np.stack(reference_clip_list, 0) 75 | deep_speech_full = np.array(self.data_dic[video_name]['clip_data_list'][source_anchor]['deep_speech_list']) 76 | 77 | # # display data 78 | # display_source = np.concatenate(source_clip_list,1) 79 | # display_source_mask = np.concatenate(source_clip_mask_list,1) 80 | # display_reference0 = np.concatenate([reference_clip_list[0][:,:,:3],reference_clip_list[0][:,:,3:6],reference_clip_list[0][:,:,6:9], 81 | # reference_clip_list[0][:,:,9:12],reference_clip_list[0][:,:,12:15]],1) 82 | # display_reference1 = np.concatenate([reference_clip_list[1][:, :, :3], reference_clip_list[1][:, :, 3:6], 83 | # reference_clip_list[1][:, :, 6:9], 84 | # reference_clip_list[1][:, :, 9:12], reference_clip_list[1][:, :, 12:15]],1) 85 | # display_reference2 = np.concatenate([reference_clip_list[2][:, :, 
:3], reference_clip_list[2][:, :, 3:6], 86 | # reference_clip_list[2][:, :, 6:9], 87 | # reference_clip_list[2][:, :, 9:12], reference_clip_list[2][:, :, 12:15]],1) 88 | # display_reference3 = np.concatenate([reference_clip_list[3][:, :, :3], reference_clip_list[3][:, :, 3:6], 89 | # reference_clip_list[3][:, :, 6:9], 90 | # reference_clip_list[3][:, :, 9:12], reference_clip_list[3][:, :, 12:15]],1) 91 | # display_reference4 = np.concatenate([reference_clip_list[4][:, :, :3], reference_clip_list[4][:, :, 3:6], 92 | # reference_clip_list[4][:, :, 6:9], 93 | # reference_clip_list[4][:, :, 9:12], reference_clip_list[4][:, :, 12:15]],1) 94 | # merge_img = np.concatenate([display_source,display_source_mask, 95 | # display_reference0,display_reference1,display_reference2,display_reference3, 96 | # display_reference4],0) 97 | # cv2.imshow('test',(merge_img[:,:,::-1] * 255).astype(np.uint8)) 98 | # cv2.waitKey(-1) 99 | 100 | 101 | 102 | # # 2 tensor 103 | source_clip = torch.from_numpy(source_clip).float().permute(0, 3, 1, 2) 104 | source_clip_mask = torch.from_numpy(source_clip_mask).float().permute(0, 3, 1, 2) 105 | reference_clip = torch.from_numpy(reference_clip).float().permute(0, 3, 1, 2) 106 | deep_speech_clip = torch.from_numpy(deep_speech_clip).float().permute(0, 2, 1) 107 | deep_speech_full = torch.from_numpy(deep_speech_full).permute(1, 0) 108 | return source_clip,source_clip_mask, reference_clip,deep_speech_clip,deep_speech_full 109 | 110 | def __len__(self): 111 | return self.length 112 | -------------------------------------------------------------------------------- /dataset/dataset_DINet_frame.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import json 4 | import random 5 | import cv2 6 | 7 | from torch.utils.data import Dataset 8 | 9 | 10 | def get_data(json_name,augment_num): 11 | print('start loading data') 12 | with open(json_name,'r') as f: 13 | data_dic = json.load(f) 14 | data_dic_name_list = [] 15 | for augment_index in range(augment_num): 16 | for video_name in data_dic.keys(): 17 | data_dic_name_list.append(video_name) 18 | random.shuffle(data_dic_name_list) 19 | print('finish loading') 20 | return data_dic_name_list,data_dic 21 | 22 | 23 | class DINetDataset(Dataset): 24 | def __init__(self,path_json,augment_num,mouth_region_size): 25 | super(DINetDataset, self).__init__() 26 | self.data_dic_name_list,self.data_dic = get_data(path_json,augment_num) 27 | self.mouth_region_size = mouth_region_size 28 | self.radius = mouth_region_size//2 29 | self.radius_1_4 = self.radius//4 30 | self.img_h = self.radius * 3 + self.radius_1_4 31 | self.img_w = self.radius * 2 + self.radius_1_4 * 2 32 | self.length = len(self.data_dic_name_list) 33 | 34 | def __getitem__(self, index): 35 | video_name = self.data_dic_name_list[index] 36 | video_clip_num = len(self.data_dic[video_name]['clip_data_list']) 37 | random_anchor = random.sample(range(video_clip_num), 6) 38 | source_anchor, reference_anchor_list = random_anchor[0],random_anchor[1:] 39 | ## load source image 40 | source_image_path_list = self.data_dic[video_name]['clip_data_list'][source_anchor]['frame_path_list'] 41 | source_random_index = random.sample(range(2, 7), 1)[0] 42 | source_image_data = cv2.imread(source_image_path_list[source_random_index])[:, :, ::-1] 43 | source_image_data = cv2.resize(source_image_data, (self.img_w, self.img_h))/ 255.0 44 | source_image_mask = source_image_data.copy() 45 | 
source_image_mask[self.radius:self.radius+self.mouth_region_size,self.radius_1_4:self.radius_1_4 +self.mouth_region_size ,:] = 0 46 | 47 | ## load deep speech feature 48 | deepspeech_feature = np.array(self.data_dic[video_name]['clip_data_list'][source_anchor]['deep_speech_list'][source_random_index - 2:source_random_index + 3]) 49 | 50 | ## load reference images 51 | reference_frame_data_list = [] 52 | for reference_anchor in reference_anchor_list: 53 | reference_frame_path_list = self.data_dic[video_name]['clip_data_list'][reference_anchor]['frame_path_list'] 54 | reference_random_index = random.sample(range(9), 1)[0] 55 | reference_frame_path = reference_frame_path_list[reference_random_index] 56 | reference_frame_data = cv2.imread(reference_frame_path)[:, :, ::-1] 57 | reference_frame_data = cv2.resize(reference_frame_data, (self.img_w, self.img_h))/ 255.0 58 | reference_frame_data_list.append(reference_frame_data) 59 | reference_clip_data = np.concatenate(reference_frame_data_list, 2) 60 | 61 | # display the source image and reference images 62 | # display_img = np.concatenate([source_image_data,source_image_mask]+reference_frame_data_list,1) 63 | # cv2.imshow('image display',(display_img[:,:,::-1] * 255).astype(np.uint8)) 64 | # cv2.waitKey(-1) 65 | 66 | # # to tensor 67 | source_image_data = torch.from_numpy(source_image_data).float().permute(2,0,1) 68 | source_image_mask = torch.from_numpy(source_image_mask).float().permute(2,0,1) 69 | reference_clip_data = torch.from_numpy(reference_clip_data).float().permute(2,0,1) 70 | deepspeech_feature = torch.from_numpy(deepspeech_feature).float().permute(1,0) 71 | return source_image_data,source_image_mask, reference_clip_data,deepspeech_feature 72 | 73 | def __len__(self): 74 | return self.length 75 | 76 | 77 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | from utils.deep_speech import DeepSpeech 2 | from utils.data_processing import load_landmark_openface,compute_crop_radius 3 | from config.config import DINetInferenceOptions 4 | from models.DINet import DINet 5 | 6 | import numpy as np 7 | import glob 8 | import os 9 | import cv2 10 | import torch 11 | import subprocess 12 | import random 13 | from collections import OrderedDict 14 | 15 | def extract_frames_from_video(video_path,save_dir): 16 | videoCapture = cv2.VideoCapture(video_path) 17 | fps = videoCapture.get(cv2.CAP_PROP_FPS) 18 | if int(fps) != 25: 19 | print('warning: the input video is not 25 fps, it would be better to trans it to 25 fps!') 20 | frames = videoCapture.get(cv2.CAP_PROP_FRAME_COUNT) 21 | frame_height = videoCapture.get(cv2.CAP_PROP_FRAME_HEIGHT) 22 | frame_width = videoCapture.get(cv2.CAP_PROP_FRAME_WIDTH) 23 | for i in range(int(frames)): 24 | ret, frame = videoCapture.read() 25 | result_path = os.path.join(save_dir, str(i).zfill(6) + '.jpg') 26 | cv2.imwrite(result_path, frame) 27 | return (int(frame_width),int(frame_height)) 28 | 29 | if __name__ == '__main__': 30 | # load config 31 | opt = DINetInferenceOptions().parse_args() 32 | if not os.path.exists(opt.source_video_path): 33 | raise ('wrong video path : {}'.format(opt.source_video_path)) 34 | ############################################## extract frames from source video ############################################## 35 | print('extracting frames from video: {}'.format(opt.source_video_path)) 36 | video_frame_dir = opt.source_video_path.replace('.mp4', '') 37 | if not 
os.path.exists(video_frame_dir): 38 | os.mkdir(video_frame_dir) 39 | video_size = extract_frames_from_video(opt.source_video_path,video_frame_dir) 40 | ############################################## extract deep speech feature ############################################## 41 | print('extracting deepspeech feature from : {}'.format(opt.driving_audio_path)) 42 | if not os.path.exists(opt.deepspeech_model_path): 43 | raise ('pls download pretrained model of deepspeech') 44 | DSModel = DeepSpeech(opt.deepspeech_model_path) 45 | if not os.path.exists(opt.driving_audio_path): 46 | raise ('wrong audio path :{}'.format(opt.driving_audio_path)) 47 | ds_feature = DSModel.compute_audio_feature(opt.driving_audio_path) 48 | res_frame_length = ds_feature.shape[0] 49 | ds_feature_padding = np.pad(ds_feature, ((2, 2), (0, 0)), mode='edge') 50 | ############################################## load facial landmark ############################################## 51 | print('loading facial landmarks from : {}'.format(opt.source_openface_landmark_path)) 52 | if not os.path.exists(opt.source_openface_landmark_path): 53 | raise ('wrong facial landmark path :{}'.format(opt.source_openface_landmark_path)) 54 | video_landmark_data = load_landmark_openface(opt.source_openface_landmark_path).astype(np.int) 55 | ############################################## align frame with driving audio ############################################## 56 | print('aligning frames with driving audio') 57 | video_frame_path_list = glob.glob(os.path.join(video_frame_dir, '*.jpg')) 58 | if len(video_frame_path_list) != video_landmark_data.shape[0]: 59 | raise ('video frames are misaligned with detected landmarks') 60 | video_frame_path_list.sort() 61 | video_frame_path_list_cycle = video_frame_path_list + video_frame_path_list[::-1] 62 | video_landmark_data_cycle = np.concatenate([video_landmark_data, np.flip(video_landmark_data, 0)], 0) 63 | video_frame_path_list_cycle_length = len(video_frame_path_list_cycle) 64 | if video_frame_path_list_cycle_length >= res_frame_length: 65 | res_video_frame_path_list = video_frame_path_list_cycle[:res_frame_length] 66 | res_video_landmark_data = video_landmark_data_cycle[:res_frame_length, :, :] 67 | else: 68 | divisor = res_frame_length // video_frame_path_list_cycle_length 69 | remainder = res_frame_length % video_frame_path_list_cycle_length 70 | res_video_frame_path_list = video_frame_path_list_cycle * divisor + video_frame_path_list_cycle[:remainder] 71 | res_video_landmark_data = np.concatenate([video_landmark_data_cycle]* divisor + [video_landmark_data_cycle[:remainder, :, :]],0) 72 | res_video_frame_path_list_pad = [video_frame_path_list_cycle[0]] * 2 \ 73 | + res_video_frame_path_list \ 74 | + [video_frame_path_list_cycle[-1]] * 2 75 | res_video_landmark_data_pad = np.pad(res_video_landmark_data, ((2, 2), (0, 0), (0, 0)), mode='edge') 76 | assert ds_feature_padding.shape[0] == len(res_video_frame_path_list_pad) == res_video_landmark_data_pad.shape[0] 77 | pad_length = ds_feature_padding.shape[0] 78 | 79 | ############################################## randomly select 5 reference images ############################################## 80 | print('selecting five reference images') 81 | ref_img_list = [] 82 | resize_w = int(opt.mouth_region_size + opt.mouth_region_size // 4) 83 | resize_h = int((opt.mouth_region_size // 2) * 3 + opt.mouth_region_size // 8) 84 | ref_index_list = random.sample(range(5, len(res_video_frame_path_list_pad) - 2), 5) 85 | for ref_index in ref_index_list: 86 | 
crop_flag,crop_radius = compute_crop_radius(video_size,res_video_landmark_data_pad[ref_index - 5:ref_index, :, :]) 87 | if not crop_flag: 88 | raise ('our method can not handle videos with large change of facial size!!') 89 | crop_radius_1_4 = crop_radius // 4 90 | ref_img = cv2.imread(res_video_frame_path_list_pad[ref_index- 3])[:, :, ::-1] 91 | ref_landmark = res_video_landmark_data_pad[ref_index - 3, :, :] 92 | ref_img_crop = ref_img[ 93 | ref_landmark[29, 1] - crop_radius:ref_landmark[29, 1] + crop_radius * 2 + crop_radius_1_4, 94 | ref_landmark[33, 0] - crop_radius - crop_radius_1_4:ref_landmark[33, 0] + crop_radius +crop_radius_1_4, 95 | :] 96 | ref_img_crop = cv2.resize(ref_img_crop,(resize_w,resize_h)) 97 | ref_img_crop = ref_img_crop / 255.0 98 | ref_img_list.append(ref_img_crop) 99 | ref_video_frame = np.concatenate(ref_img_list, 2) 100 | ref_img_tensor = torch.from_numpy(ref_video_frame).permute(2, 0, 1).unsqueeze(0).float().cuda() 101 | 102 | ############################################## load pretrained model weight ############################################## 103 | print('loading pretrained model from: {}'.format(opt.pretrained_clip_DINet_path)) 104 | model = DINet(opt.source_channel, opt.ref_channel, opt.audio_channel).cuda() 105 | if not os.path.exists(opt.pretrained_clip_DINet_path): 106 | raise ('wrong path of pretrained model weight: {}'.format(opt.pretrained_clip_DINet_path)) 107 | state_dict = torch.load(opt.pretrained_clip_DINet_path)['state_dict']['net_g'] 108 | new_state_dict = OrderedDict() 109 | for k, v in state_dict.items(): 110 | name = k[7:] # remove module. 111 | new_state_dict[name] = v 112 | model.load_state_dict(new_state_dict) 113 | model.eval() 114 | ############################################## inference frame by frame ############################################## 115 | if not os.path.exists(opt.res_video_dir): 116 | os.mkdir(opt.res_video_dir) 117 | res_video_path = os.path.join(opt.res_video_dir,os.path.basename(opt.source_video_path)[:-4] + '_facial_dubbing.mp4') 118 | if os.path.exists(res_video_path): 119 | os.remove(res_video_path) 120 | res_face_path = res_video_path.replace('_facial_dubbing.mp4', '_synthetic_face.mp4') 121 | if os.path.exists(res_face_path): 122 | os.remove(res_face_path) 123 | videowriter = cv2.VideoWriter(res_video_path, cv2.VideoWriter_fourcc(*'XVID'), 25, video_size) 124 | videowriter_face = cv2.VideoWriter(res_face_path, cv2.VideoWriter_fourcc(*'XVID'), 25, (resize_w, resize_h)) 125 | for clip_end_index in range(5, pad_length, 1): 126 | print('synthesizing {}/{} frame'.format(clip_end_index - 5, pad_length - 5)) 127 | crop_flag, crop_radius = compute_crop_radius(video_size,res_video_landmark_data_pad[clip_end_index - 5:clip_end_index, :, :],random_scale = 1.05) 128 | if not crop_flag: 129 | raise ('our method can not handle videos with large change of facial size!!') 130 | crop_radius_1_4 = crop_radius // 4 131 | frame_data = cv2.imread(res_video_frame_path_list_pad[clip_end_index - 3])[:, :, ::-1] 132 | frame_landmark = res_video_landmark_data_pad[clip_end_index - 3, :, :] 133 | crop_frame_data = frame_data[ 134 | frame_landmark[29, 1] - crop_radius:frame_landmark[29, 1] + crop_radius * 2 + crop_radius_1_4, 135 | frame_landmark[33, 0] - crop_radius - crop_radius_1_4:frame_landmark[33, 0] + crop_radius +crop_radius_1_4, 136 | :] 137 | crop_frame_h,crop_frame_w = crop_frame_data.shape[0],crop_frame_data.shape[1] 138 | crop_frame_data = cv2.resize(crop_frame_data, (resize_w,resize_h)) # [32:224, 32:224, :] 139 | 
crop_frame_data = crop_frame_data / 255.0 140 | crop_frame_data[opt.mouth_region_size//2:opt.mouth_region_size//2 + opt.mouth_region_size, 141 | opt.mouth_region_size//8:opt.mouth_region_size//8 + opt.mouth_region_size, :] = 0 142 | 143 | crop_frame_tensor = torch.from_numpy(crop_frame_data).float().cuda().permute(2, 0, 1).unsqueeze(0) 144 | deepspeech_tensor = torch.from_numpy(ds_feature_padding[clip_end_index - 5:clip_end_index, :]).permute(1, 0).unsqueeze(0).float().cuda() 145 | with torch.no_grad(): 146 | pre_frame = model(crop_frame_tensor, ref_img_tensor, deepspeech_tensor) 147 | pre_frame = pre_frame.squeeze(0).permute(1, 2, 0).detach().cpu().numpy() * 255 148 | videowriter_face.write(pre_frame[:, :, ::-1].copy().astype(np.uint8)) 149 | pre_frame_resize = cv2.resize(pre_frame, (crop_frame_w,crop_frame_h)) 150 | frame_data[ 151 | frame_landmark[29, 1] - crop_radius: 152 | frame_landmark[29, 1] + crop_radius * 2, 153 | frame_landmark[33, 0] - crop_radius - crop_radius_1_4: 154 | frame_landmark[33, 0] + crop_radius + crop_radius_1_4, 155 | :] = pre_frame_resize[:crop_radius * 3,:,:] 156 | videowriter.write(frame_data[:, :, ::-1]) 157 | videowriter.release() 158 | videowriter_face.release() 159 | video_add_audio_path = res_video_path.replace('.mp4', '_add_audio.mp4') 160 | if os.path.exists(video_add_audio_path): 161 | os.remove(video_add_audio_path) 162 | cmd = 'ffmpeg -i {} -i {} -c:v copy -c:a aac -strict experimental -map 0:v:0 -map 1:a:0 {}'.format( 163 | res_video_path, 164 | opt.driving_audio_path, 165 | video_add_audio_path) 166 | subprocess.call(cmd, shell=True) 167 | 168 | 169 | 170 | 171 | 172 | 173 | 174 | -------------------------------------------------------------------------------- /models/DINet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | import math 5 | import cv2 6 | import numpy as np 7 | from sync_batchnorm import SynchronizedBatchNorm2d as BatchNorm2d 8 | from sync_batchnorm import SynchronizedBatchNorm1d as BatchNorm1d 9 | 10 | def make_coordinate_grid_3d(spatial_size, type): 11 | ''' 12 | generate 3D coordinate grid 13 | ''' 14 | d, h, w = spatial_size 15 | x = torch.arange(w).type(type) 16 | y = torch.arange(h).type(type) 17 | z = torch.arange(d).type(type) 18 | x = (2 * (x / (w - 1)) - 1) 19 | y = (2 * (y / (h - 1)) - 1) 20 | z = (2 * (z / (d - 1)) - 1) 21 | yy = y.view(1,-1, 1).repeat(d,1, w) 22 | xx = x.view(1,1, -1).repeat(d,h, 1) 23 | zz = z.view(-1,1,1).repeat(1,h,w) 24 | meshed = torch.cat([xx.unsqueeze_(3), yy.unsqueeze_(3)], 3) 25 | return meshed,zz 26 | 27 | class ResBlock1d(nn.Module): 28 | ''' 29 | basic block 30 | ''' 31 | def __init__(self, in_features,out_features, kernel_size, padding): 32 | super(ResBlock1d, self).__init__() 33 | self.in_features = in_features 34 | self.out_features = out_features 35 | self.conv1 = nn.Conv1d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size, 36 | padding=padding) 37 | self.conv2 = nn.Conv1d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, 38 | padding=padding) 39 | if out_features != in_features: 40 | self.channel_conv = nn.Conv1d(in_features,out_features,1) 41 | self.norm1 = BatchNorm1d(in_features) 42 | self.norm2 = BatchNorm1d(in_features) 43 | self.relu = nn.ReLU() 44 | def forward(self, x): 45 | out = self.norm1(x) 46 | out = self.relu(out) 47 | out = self.conv1(out) 48 | out = self.norm2(out) 49 | out = self.relu(out) 50 | out = 
self.conv2(out) 51 | if self.in_features != self.out_features: 52 | out += self.channel_conv(x) 53 | else: 54 | out += x 55 | return out 56 | 57 | class ResBlock2d(nn.Module): 58 | ''' 59 | basic block 60 | ''' 61 | def __init__(self, in_features,out_features, kernel_size, padding): 62 | super(ResBlock2d, self).__init__() 63 | self.in_features = in_features 64 | self.out_features = out_features 65 | self.conv1 = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size, 66 | padding=padding) 67 | self.conv2 = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, 68 | padding=padding) 69 | if out_features != in_features: 70 | self.channel_conv = nn.Conv2d(in_features,out_features,1) 71 | self.norm1 = BatchNorm2d(in_features) 72 | self.norm2 = BatchNorm2d(in_features) 73 | self.relu = nn.ReLU() 74 | def forward(self, x): 75 | out = self.norm1(x) 76 | out = self.relu(out) 77 | out = self.conv1(out) 78 | out = self.norm2(out) 79 | out = self.relu(out) 80 | out = self.conv2(out) 81 | if self.in_features != self.out_features: 82 | out += self.channel_conv(x) 83 | else: 84 | out += x 85 | return out 86 | 87 | class UpBlock2d(nn.Module): 88 | ''' 89 | basic block 90 | ''' 91 | def __init__(self, in_features, out_features, kernel_size=3, padding=1): 92 | super(UpBlock2d, self).__init__() 93 | self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, 94 | padding=padding) 95 | self.norm = BatchNorm2d(out_features) 96 | self.relu = nn.ReLU() 97 | def forward(self, x): 98 | out = F.interpolate(x, scale_factor=2) 99 | out = self.conv(out) 100 | out = self.norm(out) 101 | out = F.relu(out) 102 | return out 103 | 104 | class DownBlock1d(nn.Module): 105 | ''' 106 | basic block 107 | ''' 108 | def __init__(self, in_features, out_features, kernel_size, padding): 109 | super(DownBlock1d, self).__init__() 110 | self.conv = nn.Conv1d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, 111 | padding=padding,stride=2) 112 | self.norm = BatchNorm1d(out_features) 113 | self.relu = nn.ReLU() 114 | def forward(self, x): 115 | out = self.conv(x) 116 | out = self.norm(out) 117 | out = self.relu(out) 118 | return out 119 | 120 | class DownBlock2d(nn.Module): 121 | ''' 122 | basic block 123 | ''' 124 | def __init__(self, in_features, out_features, kernel_size=3, padding=1, stride=2): 125 | super(DownBlock2d, self).__init__() 126 | self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, 127 | padding=padding, stride=stride) 128 | self.norm = BatchNorm2d(out_features) 129 | self.relu = nn.ReLU() 130 | def forward(self, x): 131 | out = self.conv(x) 132 | out = self.norm(out) 133 | out = self.relu(out) 134 | return out 135 | 136 | class SameBlock1d(nn.Module): 137 | ''' 138 | basic block 139 | ''' 140 | def __init__(self, in_features, out_features, kernel_size, padding): 141 | super(SameBlock1d, self).__init__() 142 | self.conv = nn.Conv1d(in_channels=in_features, out_channels=out_features, 143 | kernel_size=kernel_size, padding=padding) 144 | self.norm = BatchNorm1d(out_features) 145 | self.relu = nn.ReLU() 146 | def forward(self, x): 147 | out = self.conv(x) 148 | out = self.norm(out) 149 | out = self.relu(out) 150 | return out 151 | 152 | class SameBlock2d(nn.Module): 153 | ''' 154 | basic block 155 | ''' 156 | def __init__(self, in_features, out_features, kernel_size=3, padding=1): 157 | super(SameBlock2d, self).__init__() 158 | self.conv = 
nn.Conv2d(in_channels=in_features, out_channels=out_features, 159 | kernel_size=kernel_size, padding=padding) 160 | self.norm = BatchNorm2d(out_features) 161 | self.relu = nn.ReLU() 162 | def forward(self, x): 163 | out = self.conv(x) 164 | out = self.norm(out) 165 | out = self.relu(out) 166 | return out 167 | 168 | class AdaAT(nn.Module): 169 | ''' 170 | AdaAT operator 171 | ''' 172 | def __init__(self, para_ch,feature_ch): 173 | super(AdaAT, self).__init__() 174 | self.para_ch = para_ch 175 | self.feature_ch = feature_ch 176 | self.commn_linear = nn.Sequential( 177 | nn.Linear(para_ch, para_ch), 178 | nn.ReLU() 179 | ) 180 | self.scale = nn.Sequential( 181 | nn.Linear(para_ch, feature_ch), 182 | nn.Sigmoid() 183 | ) 184 | self.rotation = nn.Sequential( 185 | nn.Linear(para_ch, feature_ch), 186 | nn.Tanh() 187 | ) 188 | self.translation = nn.Sequential( 189 | nn.Linear(para_ch, 2 * feature_ch), 190 | nn.Tanh() 191 | ) 192 | self.tanh = nn.Tanh() 193 | self.sigmoid = nn.Sigmoid() 194 | 195 | def forward(self, feature_map,para_code): 196 | batch,d, h, w = feature_map.size(0), feature_map.size(1), feature_map.size(2), feature_map.size(3) 197 | para_code = self.commn_linear(para_code) 198 | scale = self.scale(para_code).unsqueeze(-1) * 2 199 | angle = self.rotation(para_code).unsqueeze(-1) * 3.14159# 200 | rotation_matrix = torch.cat([torch.cos(angle), -torch.sin(angle), torch.sin(angle), torch.cos(angle)], -1) 201 | rotation_matrix = rotation_matrix.view(batch, self.feature_ch, 2, 2) 202 | translation = self.translation(para_code).view(batch, self.feature_ch, 2) 203 | grid_xy, grid_z = make_coordinate_grid_3d((d, h, w), feature_map.type()) 204 | grid_xy = grid_xy.unsqueeze(0).repeat(batch, 1, 1, 1, 1) 205 | grid_z = grid_z.unsqueeze(0).repeat(batch, 1, 1, 1) 206 | scale = scale.unsqueeze(2).unsqueeze(3).repeat(1, 1, h, w, 1) 207 | rotation_matrix = rotation_matrix.unsqueeze(2).unsqueeze(3).repeat(1, 1, h, w, 1, 1) 208 | translation = translation.unsqueeze(2).unsqueeze(3).repeat(1, 1, h, w, 1) 209 | trans_grid = torch.matmul(rotation_matrix, grid_xy.unsqueeze(-1)).squeeze(-1) * scale + translation 210 | full_grid = torch.cat([trans_grid, grid_z.unsqueeze(-1)], -1) 211 | trans_feature = F.grid_sample(feature_map.unsqueeze(1), full_grid).squeeze(1) 212 | return trans_feature 213 | 214 | class DINet(nn.Module): 215 | def __init__(self, source_channel,ref_channel,audio_channel): 216 | super(DINet, self).__init__() 217 | self.source_in_conv = nn.Sequential( 218 | SameBlock2d(source_channel,64,kernel_size=7, padding=3), 219 | DownBlock2d(64, 128, kernel_size=3, padding=1), 220 | DownBlock2d(128,256,kernel_size=3, padding=1) 221 | ) 222 | self.ref_in_conv = nn.Sequential( 223 | SameBlock2d(ref_channel, 64, kernel_size=7, padding=3), 224 | DownBlock2d(64, 128, kernel_size=3, padding=1), 225 | DownBlock2d(128, 256, kernel_size=3, padding=1), 226 | ) 227 | self.trans_conv = nn.Sequential( 228 | # 20 →10 229 | SameBlock2d(512, 128, kernel_size=3, padding=1), 230 | SameBlock2d(128, 128, kernel_size=11, padding=5), 231 | SameBlock2d(128, 128, kernel_size=11, padding=5), 232 | DownBlock2d(128, 128, kernel_size=3, padding=1), 233 | # 10 →5 234 | SameBlock2d(128, 128, kernel_size=7, padding=3), 235 | SameBlock2d(128, 128, kernel_size=7, padding=3), 236 | DownBlock2d(128, 128, kernel_size=3, padding=1), 237 | # 5 →3 238 | SameBlock2d(128, 128, kernel_size=3, padding=1), 239 | DownBlock2d(128, 128, kernel_size=3, padding=1), 240 | # 3 →2 241 | SameBlock2d(128, 128, kernel_size=3, padding=1), 242 | 
DownBlock2d(128, 128, kernel_size=3, padding=1), 243 | 244 | ) 245 | self.audio_encoder = nn.Sequential( 246 | SameBlock1d(audio_channel, 128, kernel_size=5, padding=2), 247 | ResBlock1d(128, 128, 3, 1), 248 | DownBlock1d(128, 128, 3, 1), 249 | ResBlock1d(128, 128, 3, 1), 250 | DownBlock1d(128, 128, 3, 1), 251 | SameBlock1d(128, 128, kernel_size=3, padding=1) 252 | ) 253 | 254 | appearance_conv_list = [] 255 | for i in range(2): 256 | appearance_conv_list.append( 257 | nn.Sequential( 258 | ResBlock2d(256, 256, 3, 1), 259 | ResBlock2d(256, 256, 3, 1), 260 | ResBlock2d(256, 256, 3, 1), 261 | ResBlock2d(256, 256, 3, 1), 262 | ) 263 | ) 264 | self.appearance_conv_list = nn.ModuleList(appearance_conv_list) 265 | self.adaAT = AdaAT(256, 256) 266 | self.out_conv = nn.Sequential( 267 | SameBlock2d(512, 128, kernel_size=3, padding=1), 268 | UpBlock2d(128,128,kernel_size=3, padding=1), 269 | ResBlock2d(128, 128, 3, 1), 270 | UpBlock2d(128, 128, kernel_size=3, padding=1), 271 | nn.Conv2d(128, 3, kernel_size=(7, 7), padding=(3, 3)), 272 | nn.Sigmoid() 273 | ) 274 | self.global_avg2d = nn.AdaptiveAvgPool2d(1) 275 | self.global_avg1d = nn.AdaptiveAvgPool1d(1) 276 | def forward(self, source_img,ref_img,audio_feature): 277 | ## source image encoder 278 | source_in_feature = self.source_in_conv(source_img) 279 | ## reference image encoder 280 | ref_in_feature = self.ref_in_conv(ref_img) 281 | ## alignment encoder 282 | img_para = self.trans_conv(torch.cat([source_in_feature,ref_in_feature],1)) 283 | img_para = self.global_avg2d(img_para).squeeze(3).squeeze(2) 284 | ## audio encoder 285 | audio_para = self.audio_encoder(audio_feature) 286 | audio_para = self.global_avg1d(audio_para).squeeze(2) 287 | ## concat alignment feature and audio feature 288 | trans_para = torch.cat([img_para,audio_para],1) 289 | ## use AdaAT do spatial deformation on reference feature maps 290 | ref_trans_feature = self.appearance_conv_list[0](ref_in_feature) 291 | ref_trans_feature = self.adaAT(ref_trans_feature, trans_para) 292 | ref_trans_feature = self.appearance_conv_list[1](ref_trans_feature) 293 | ## feature decoder 294 | merge_feature = torch.cat([source_in_feature,ref_trans_feature],1) 295 | out = self.out_conv(merge_feature) 296 | return out 297 | 298 | 299 | 300 | -------------------------------------------------------------------------------- /models/Discriminator.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch.nn.functional as F 3 | 4 | class DownBlock2d(nn.Module): 5 | def __init__(self, in_features, out_features, kernel_size=4, pool=False): 6 | super(DownBlock2d, self).__init__() 7 | self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size) 8 | self.pool = pool 9 | def forward(self, x): 10 | out = x 11 | out = self.conv(out) 12 | out = F.leaky_relu(out, 0.2) 13 | if self.pool: 14 | out = F.avg_pool2d(out, (2, 2)) 15 | return out 16 | 17 | 18 | class Discriminator(nn.Module): 19 | """ 20 | Discriminator for GAN loss 21 | """ 22 | def __init__(self, num_channels, block_expansion=64, num_blocks=4, max_features=512): 23 | super(Discriminator, self).__init__() 24 | down_blocks = [] 25 | for i in range(num_blocks): 26 | down_blocks.append( 27 | DownBlock2d(num_channels if i == 0 else min(max_features, block_expansion * (2 ** i)), 28 | min(max_features, block_expansion * (2 ** (i + 1))), 29 | kernel_size=4, pool=(i != num_blocks - 1))) 30 | self.down_blocks = nn.ModuleList(down_blocks) 31 | self.conv = 
nn.Conv2d(self.down_blocks[-1].conv.out_channels, out_channels=1, kernel_size=1) 32 | def forward(self, x): 33 | feature_maps = [] 34 | out = x 35 | for down_block in self.down_blocks: 36 | feature_maps.append(down_block(out)) 37 | out = feature_maps[-1] 38 | out = self.conv(out) 39 | return feature_maps, out 40 | -------------------------------------------------------------------------------- /models/Syncnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | class ResBlock1d(nn.Module): 4 | ''' 5 | basic block (no BN) 6 | ''' 7 | def __init__(self, in_features,out_features, kernel_size, padding): 8 | super(ResBlock1d, self).__init__() 9 | self.in_features = in_features 10 | self.out_features = out_features 11 | self.conv1 = nn.Conv1d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size, 12 | padding=padding) 13 | self.conv2 = nn.Conv1d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, 14 | padding=padding) 15 | if out_features != in_features: 16 | self.channel_conv = nn.Conv1d(in_features,out_features,1) 17 | self.relu = nn.ReLU() 18 | def forward(self, x): 19 | out = self.relu(x) 20 | out = self.conv1(out) 21 | out = self.relu(out) 22 | out = self.conv2(out) 23 | if self.in_features != self.out_features: 24 | out += self.channel_conv(x) 25 | else: 26 | out += x 27 | return out 28 | 29 | class ResBlock2d(nn.Module): 30 | ''' 31 | basic block (no BN) 32 | ''' 33 | def __init__(self, in_features,out_features, kernel_size, padding): 34 | super(ResBlock2d, self).__init__() 35 | self.in_features = in_features 36 | self.out_features = out_features 37 | self.conv1 = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size, 38 | padding=padding) 39 | self.conv2 = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, 40 | padding=padding) 41 | if out_features != in_features: 42 | self.channel_conv = nn.Conv2d(in_features,out_features,1) 43 | self.relu = nn.ReLU() 44 | def forward(self, x): 45 | out = self.relu(x) 46 | out = self.conv1(out) 47 | out = self.relu(out) 48 | out = self.conv2(out) 49 | if self.in_features != self.out_features: 50 | out += self.channel_conv(x) 51 | else: 52 | out += x 53 | return out 54 | 55 | class DownBlock1d(nn.Module): 56 | ''' 57 | basic block (no BN) 58 | ''' 59 | def __init__(self, in_features, out_features, kernel_size, padding): 60 | super(DownBlock1d, self).__init__() 61 | self.conv = nn.Conv1d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, 62 | padding=padding,stride=2) 63 | self.relu = nn.ReLU() 64 | def forward(self, x): 65 | out = self.conv(x) 66 | out = self.relu(out) 67 | return out 68 | 69 | class DownBlock2d(nn.Module): 70 | ''' 71 | basic block (no BN) 72 | ''' 73 | def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1): 74 | super(DownBlock2d, self).__init__() 75 | self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, 76 | padding=padding, groups=groups) 77 | self.pool = nn.AvgPool2d(kernel_size=(2, 2)) 78 | self.relu = nn.ReLU() 79 | def forward(self, x): 80 | out = self.conv(x) 81 | out = self.relu(out) 82 | out = self.pool(out) 83 | return out 84 | 85 | class SameBlock1d(nn.Module): 86 | ''' 87 | basic block (no BN) 88 | ''' 89 | def __init__(self, in_features, out_features, kernel_size, padding): 90 | super(SameBlock1d, self).__init__() 91 | self.conv = 
nn.Conv1d(in_channels=in_features, out_channels=out_features, 92 | kernel_size=kernel_size, padding=padding) 93 | self.relu = nn.ReLU() 94 | def forward(self, x): 95 | out = self.conv(x) 96 | out = self.relu(out) 97 | return out 98 | 99 | class SameBlock2d(nn.Module): 100 | ''' 101 | basic block (no BN) 102 | ''' 103 | def __init__(self, in_features, out_features, groups=1, kernel_size=3, padding=1): 104 | super(SameBlock2d, self).__init__() 105 | self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, 106 | kernel_size=kernel_size, padding=padding, groups=groups) 107 | self.relu = nn.ReLU() 108 | def forward(self, x): 109 | out = self.conv(x) 110 | out = self.relu(out) 111 | return out 112 | 113 | class FaceEncoder(nn.Module): 114 | ''' 115 | image encoder 116 | ''' 117 | def __init__(self, in_channel, out_dim): 118 | super(FaceEncoder, self).__init__() 119 | self.in_channel = in_channel 120 | self.out_dim = out_dim 121 | self.face_conv = nn.Sequential( 122 | SameBlock2d(in_channel,64,kernel_size=7,padding=3), 123 | # # 64 → 32 124 | ResBlock2d(64, 64, kernel_size=3, padding=1), 125 | DownBlock2d(64,64,3,1), 126 | SameBlock2d(64, 128), 127 | # 32 → 16 128 | ResBlock2d(128, 128, kernel_size=3, padding=1), 129 | DownBlock2d(128,128,3,1), 130 | SameBlock2d(128, 128), 131 | # 16 → 8 132 | ResBlock2d(128, 128, kernel_size=3, padding=1), 133 | DownBlock2d(128,128,3,1), 134 | SameBlock2d(128, 128), 135 | # 8 → 4 136 | ResBlock2d(128, 128, kernel_size=3, padding=1), 137 | DownBlock2d(128,128,3,1), 138 | SameBlock2d(128, 128), 139 | # 4 → 2 140 | ResBlock2d(128, 128, kernel_size=3, padding=1), 141 | DownBlock2d(128,128,3,1), 142 | SameBlock2d(128,out_dim,kernel_size=1,padding=0) 143 | ) 144 | def forward(self, x): 145 | ## b x c x h x w 146 | out = self.face_conv(x) 147 | return out 148 | 149 | class AudioEncoder(nn.Module): 150 | ''' 151 | audio encoder 152 | ''' 153 | def __init__(self, in_channel, out_dim): 154 | super(AudioEncoder, self).__init__() 155 | self.in_channel = in_channel 156 | self.out_dim = out_dim 157 | self.audio_conv = nn.Sequential( 158 | SameBlock1d(in_channel,128,kernel_size=7,padding=3), 159 | ResBlock1d(128, 128, 3, 1), 160 | # 9-5 161 | DownBlock1d(128, 128, 3, 1), 162 | ResBlock1d(128, 128, 3, 1), 163 | # 5 -3 164 | DownBlock1d(128, 128, 3, 1), 165 | ResBlock1d(128, 128, 3, 1), 166 | # 3-2 167 | DownBlock1d(128, 128, 3, 1), 168 | SameBlock1d(128,out_dim,kernel_size=3,padding=1) 169 | ) 170 | self.global_avg = nn.AdaptiveAvgPool1d(1) 171 | def forward(self, x): 172 | ## b x c x t 173 | out = self.audio_conv(x) 174 | return self.global_avg(out).squeeze(2) 175 | 176 | class SyncNet(nn.Module): 177 | ''' 178 | syncnet 179 | ''' 180 | def __init__(self, in_channel_image,in_channel_audio, out_dim): 181 | super(SyncNet, self).__init__() 182 | self.in_channel_image = in_channel_image 183 | self.in_channel_audio = in_channel_audio 184 | self.out_dim = out_dim 185 | self.face_encoder = FaceEncoder(in_channel_image,out_dim) 186 | self.audio_encoder = AudioEncoder(in_channel_audio,out_dim) 187 | self.merge_encoder = nn.Sequential( 188 | nn.Conv2d(out_dim * 2, out_dim, kernel_size=3, padding=1), 189 | nn.LeakyReLU(0.2), 190 | nn.Conv2d(out_dim, 1, kernel_size=3, padding=1), 191 | ) 192 | def forward(self, image,audio): 193 | image_embedding = self.face_encoder(image) 194 | audio_embedding = self.audio_encoder(audio).unsqueeze(2).unsqueeze(3).repeat(1,1,image_embedding.size(2),image_embedding.size(3)) 195 | concat_embedding = 
torch.cat([image_embedding,audio_embedding],1) 196 | out_score = self.merge_encoder(concat_embedding) 197 | return out_score 198 | 199 | class SyncNetPerception(nn.Module): 200 | ''' 201 | use syncnet to compute perception loss 202 | ''' 203 | def __init__(self,pretrain_path): 204 | super(SyncNetPerception, self).__init__() 205 | self.model = SyncNet(15,29,128) 206 | print('load lip sync model : {}'.format(pretrain_path)) 207 | self.model.load_state_dict(torch.load(pretrain_path)['state_dict']['net']) 208 | for param in self.model.parameters(): 209 | param.requires_grad = False 210 | self.model.eval() 211 | 212 | def forward(self, image,audio): 213 | score = self.model(image,audio) 214 | return score 215 | 216 | -------------------------------------------------------------------------------- /models/VGG19.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision import models 3 | import numpy as np 4 | 5 | 6 | class Vgg19(torch.nn.Module): 7 | """ 8 | Vgg19 network for perceptual loss 9 | """ 10 | def __init__(self, requires_grad=False): 11 | super(Vgg19, self).__init__() 12 | vgg_model = models.vgg19(pretrained=True) 13 | vgg_pretrained_features = vgg_model.features 14 | self.slice1 = torch.nn.Sequential() 15 | self.slice2 = torch.nn.Sequential() 16 | self.slice3 = torch.nn.Sequential() 17 | self.slice4 = torch.nn.Sequential() 18 | self.slice5 = torch.nn.Sequential() 19 | for x in range(2): 20 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 21 | for x in range(2, 7): 22 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 23 | for x in range(7, 12): 24 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 25 | for x in range(12, 21): 26 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 27 | for x in range(21, 30): 28 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 29 | 30 | self.mean = torch.nn.Parameter(data=torch.Tensor(np.array([0.485, 0.456, 0.406]).reshape((1, 3, 1, 1))), 31 | requires_grad=False) 32 | self.std = torch.nn.Parameter(data=torch.Tensor(np.array([0.229, 0.224, 0.225]).reshape((1, 3, 1, 1))), 33 | requires_grad=False) 34 | 35 | if not requires_grad: 36 | for param in self.parameters(): 37 | param.requires_grad = False 38 | 39 | def forward(self, X): 40 | X = (X - self.mean) / self.std 41 | h_relu1 = self.slice1(X) 42 | h_relu2 = self.slice2(h_relu1) 43 | h_relu3 = self.slice3(h_relu2) 44 | h_relu4 = self.slice4(h_relu3) 45 | h_relu5 = self.slice5(h_relu4) 46 | out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5] 47 | return out 48 | -------------------------------------------------------------------------------- /models/old/Syncnet_BN.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | class ResBlock1d(nn.Module): 5 | ''' 6 | basic block (BN) 7 | ''' 8 | def __init__(self, in_features,out_features, kernel_size, padding): 9 | super(ResBlock1d, self).__init__() 10 | self.in_features = in_features 11 | self.out_features = out_features 12 | self.conv1 = nn.Conv1d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size, 13 | padding=padding) 14 | self.conv2 = nn.Conv1d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, 15 | padding=padding) 16 | if out_features != in_features: 17 | self.channel_conv = nn.Conv1d(in_features,out_features,1) 18 | self.norm1 = nn.BatchNorm1d(in_features) 19 | self.norm2 = nn.BatchNorm1d(in_features) 20 | 
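# Pre-activation residual block: BN -> ReLU -> Conv applied twice, with a 1x1
# convolution (channel_conv) matching channels on the skip connection when
# in_features != out_features.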
self.relu = nn.ReLU() 21 | def forward(self, x): 22 | out = self.norm1(x) 23 | out = self.relu(out) 24 | out = self.conv1(out) 25 | out = self.norm2(out) 26 | out = self.relu(out) 27 | out = self.conv2(out) 28 | if self.in_features != self.out_features: 29 | out += self.channel_conv(x) 30 | else: 31 | out += x 32 | return out 33 | 34 | class ResBlock2d(nn.Module): 35 | ''' 36 | basic block (BN) 37 | ''' 38 | def __init__(self, in_features,out_features, kernel_size, padding): 39 | super(ResBlock2d, self).__init__() 40 | self.in_features = in_features 41 | self.out_features = out_features 42 | self.conv1 = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size, 43 | padding=padding) 44 | self.conv2 = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, 45 | padding=padding) 46 | if out_features != in_features: 47 | self.channel_conv = nn.Conv2d(in_features,out_features,1) 48 | self.norm1 = nn.BatchNorm2d(in_features) 49 | self.norm2 = nn.BatchNorm2d(in_features) 50 | self.relu = nn.ReLU() 51 | def forward(self, x): 52 | out = self.norm1(x) 53 | out = self.relu(out) 54 | out = self.conv1(out) 55 | out = self.norm2(out) 56 | out = self.relu(out) 57 | out = self.conv2(out) 58 | if self.in_features != self.out_features: 59 | out += self.channel_conv(x) 60 | else: 61 | out += x 62 | return out 63 | 64 | class DownBlock1d(nn.Module): 65 | ''' 66 | basic block (BN) 67 | ''' 68 | def __init__(self, in_features, out_features, kernel_size, padding): 69 | super(DownBlock1d, self).__init__() 70 | self.conv = nn.Conv1d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, 71 | padding=padding,stride=2) 72 | self.norm = nn.BatchNorm1d(out_features) 73 | self.relu = nn.ReLU() 74 | def forward(self, x): 75 | out = self.conv(x) 76 | out = self.norm(out) 77 | out = self.relu(out) 78 | return out 79 | 80 | class DownBlock2d(nn.Module): 81 | ''' 82 | basic block (BN) 83 | ''' 84 | def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1): 85 | super(DownBlock2d, self).__init__() 86 | self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, 87 | padding=padding, groups=groups) 88 | self.norm = nn.BatchNorm2d(out_features) 89 | self.pool = nn.AvgPool2d(kernel_size=(2, 2)) 90 | self.relu = nn.ReLU() 91 | def forward(self, x): 92 | out = self.conv(x) 93 | out = self.norm(out) 94 | out = self.relu(out) 95 | out = self.pool(out) 96 | return out 97 | 98 | class SameBlock1d(nn.Module): 99 | ''' 100 | basic block (BN) 101 | ''' 102 | def __init__(self, in_features, out_features, kernel_size, padding): 103 | super(SameBlock1d, self).__init__() 104 | self.conv = nn.Conv1d(in_channels=in_features, out_channels=out_features, 105 | kernel_size=kernel_size, padding=padding) 106 | self.norm = nn.BatchNorm1d(out_features) 107 | self.relu = nn.ReLU() 108 | def forward(self, x): 109 | out = self.conv(x) 110 | out = self.norm(out) 111 | out = self.relu(out) 112 | return out 113 | 114 | class SameBlock2d(nn.Module): 115 | ''' 116 | basic block (BN) 117 | ''' 118 | def __init__(self, in_features, out_features, groups=1, kernel_size=3, padding=1): 119 | super(SameBlock2d, self).__init__() 120 | self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, 121 | kernel_size=kernel_size, padding=padding, groups=groups) 122 | self.norm = nn.BatchNorm2d(out_features) 123 | self.relu = nn.ReLU() 124 | def forward(self, x): 125 | out = self.conv(x) 126 | out = self.norm(out) 
127 | out = self.relu(out) 128 | return out 129 | 130 | 131 | class FaceEncoder(nn.Module): 132 | ''' 133 | image encoder 134 | ''' 135 | def __init__(self, in_channel, out_dim): 136 | super(FaceEncoder, self).__init__() 137 | self.in_channel = in_channel 138 | self.out_dim = out_dim 139 | self.face_conv = nn.Sequential( 140 | SameBlock2d(in_channel,64,kernel_size=7,padding=3), 141 | # # 64 → 32 142 | ResBlock2d(64, 64, kernel_size=3, padding=1), 143 | DownBlock2d(64,64,3,1), 144 | SameBlock2d(64, 128), 145 | # 32 → 16 146 | ResBlock2d(128, 128, kernel_size=3, padding=1), 147 | DownBlock2d(128,128,3,1), 148 | SameBlock2d(128, 128), 149 | # 16 → 8 150 | ResBlock2d(128, 128, kernel_size=3, padding=1), 151 | DownBlock2d(128,128,3,1), 152 | SameBlock2d(128, 128), 153 | # 8 → 4 154 | ResBlock2d(128, 128, kernel_size=3, padding=1), 155 | DownBlock2d(128,128,3,1), 156 | SameBlock2d(128, 128), 157 | # 4 → 2 158 | ResBlock2d(128, 128, kernel_size=3, padding=1), 159 | DownBlock2d(128,128,3,1), 160 | SameBlock2d(128,out_dim,kernel_size=1,padding=0) 161 | ) 162 | def forward(self, x): 163 | ## b x c x h x w 164 | out = self.face_conv(x) 165 | return out 166 | 167 | class AudioEncoder(nn.Module): 168 | ''' 169 | audio encoder 170 | ''' 171 | def __init__(self, in_channel, out_dim): 172 | super(AudioEncoder, self).__init__() 173 | self.in_channel = in_channel 174 | self.out_dim = out_dim 175 | self.audio_conv = nn.Sequential( 176 | SameBlock1d(in_channel,128,kernel_size=7,padding=3), 177 | ResBlock1d(128, 128, 3, 1), 178 | # 9-5 179 | DownBlock1d(128, 128, 3, 1), 180 | ResBlock1d(128, 128, 3, 1), 181 | # 5 -3 182 | DownBlock1d(128, 128, 3, 1), 183 | ResBlock1d(128, 128, 3, 1), 184 | # 3-2 185 | DownBlock1d(128, 128, 3, 1), 186 | SameBlock1d(128,out_dim,kernel_size=3,padding=1) 187 | ) 188 | self.global_avg = nn.AdaptiveAvgPool1d(1) 189 | def forward(self, x): 190 | ## b x c x t 191 | out = self.audio_conv(x) 192 | return self.global_avg(out).squeeze(2) 193 | 194 | class SyncNet(nn.Module): 195 | def __init__(self, in_channel_image,in_channel_audio, out_dim): 196 | super(SyncNet, self).__init__() 197 | self.in_channel_image = in_channel_image 198 | self.in_channel_audio = in_channel_audio 199 | self.out_dim = out_dim 200 | self.face_encoder = FaceEncoder(in_channel_image,out_dim) 201 | self.audio_encoder = AudioEncoder(in_channel_audio,out_dim) 202 | self.merge_encoder = nn.Sequential( 203 | nn.Conv2d(out_dim * 2, out_dim, kernel_size=3, padding=1), 204 | nn.LeakyReLU(0.2), 205 | nn.Conv2d(out_dim, 1, kernel_size=3, padding=1), 206 | ) 207 | def forward(self, image,audio): 208 | image_embedding = self.face_encoder(image) 209 | audio_embedding = self.audio_encoder(audio).unsqueeze(2).unsqueeze(3).repeat(1,1,image_embedding.size(2),image_embedding.size(3)) 210 | concat_embedding = torch.cat([image_embedding,audio_embedding],1) 211 | out_score = self.merge_encoder(concat_embedding) 212 | return out_score 213 | 214 | -------------------------------------------------------------------------------- /models/old/Syncnet_halfBN.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | 5 | class ResBlock1d(nn.Module): 6 | ''' 7 | basic block (no BN) 8 | ''' 9 | def __init__(self, in_features,out_features, kernel_size, padding): 10 | super(ResBlock1d, self).__init__() 11 | self.in_features = in_features 12 | self.out_features = out_features 13 | self.conv1 = nn.Conv1d(in_channels=in_features, 
out_channels=in_features, kernel_size=kernel_size, 14 | padding=padding) 15 | self.conv2 = nn.Conv1d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, 16 | padding=padding) 17 | if out_features != in_features: 18 | self.channel_conv = nn.Conv1d(in_features,out_features,1) 19 | self.relu = nn.ReLU() 20 | def forward(self, x): 21 | out = self.relu(x) 22 | out = self.conv1(out) 23 | out = self.relu(out) 24 | out = self.conv2(out) 25 | if self.in_features != self.out_features: 26 | out += self.channel_conv(x) 27 | else: 28 | out += x 29 | return out 30 | 31 | class ResBlock2d(nn.Module): 32 | ''' 33 | basic block (BN) 34 | ''' 35 | def __init__(self, in_features,out_features, kernel_size, padding): 36 | super(ResBlock2d, self).__init__() 37 | self.in_features = in_features 38 | self.out_features = out_features 39 | self.conv1 = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size, 40 | padding=padding) 41 | self.conv2 = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, 42 | padding=padding) 43 | if out_features != in_features: 44 | self.channel_conv = nn.Conv2d(in_features,out_features,1) 45 | self.norm1 = nn.BatchNorm2d(in_features) 46 | self.norm2 = nn.BatchNorm2d(in_features) 47 | self.relu = nn.ReLU() 48 | def forward(self, x): 49 | out = self.norm1(x) 50 | out = self.relu(out) 51 | out = self.conv1(out) 52 | out = self.norm2(out) 53 | out = self.relu(out) 54 | out = self.conv2(out) 55 | if self.in_features != self.out_features: 56 | out += self.channel_conv(x) 57 | else: 58 | out += x 59 | return out 60 | 61 | class DownBlock1d(nn.Module): 62 | ''' 63 | basic block (no BN) 64 | ''' 65 | def __init__(self, in_features, out_features, kernel_size, padding): 66 | super(DownBlock1d, self).__init__() 67 | self.conv = nn.Conv1d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, 68 | padding=padding,stride=2) 69 | self.relu = nn.ReLU() 70 | def forward(self, x): 71 | out = self.conv(x) 72 | out = self.relu(out) 73 | return out 74 | 75 | class DownBlock2d(nn.Module): 76 | ''' 77 | basic block (BN) 78 | ''' 79 | def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1): 80 | super(DownBlock2d, self).__init__() 81 | self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, 82 | padding=padding, groups=groups) 83 | self.norm = nn.BatchNorm2d(out_features) 84 | self.pool = nn.AvgPool2d(kernel_size=(2, 2)) 85 | self.relu = nn.ReLU() 86 | def forward(self, x): 87 | out = self.conv(x) 88 | out = self.norm(out) 89 | out = self.relu(out) 90 | out = self.pool(out) 91 | return out 92 | 93 | class SameBlock1d(nn.Module): 94 | ''' 95 | basic block (no BN) 96 | ''' 97 | def __init__(self, in_features, out_features, kernel_size, padding): 98 | super(SameBlock1d, self).__init__() 99 | self.conv = nn.Conv1d(in_channels=in_features, out_channels=out_features, 100 | kernel_size=kernel_size, padding=padding) 101 | self.relu = nn.ReLU() 102 | def forward(self, x): 103 | out = self.conv(x) 104 | out = self.relu(out) 105 | return out 106 | 107 | class SameBlock2d(nn.Module): 108 | ''' 109 | basic block (BN) 110 | ''' 111 | def __init__(self, in_features, out_features, groups=1, kernel_size=3, padding=1): 112 | super(SameBlock2d, self).__init__() 113 | self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, 114 | kernel_size=kernel_size, padding=padding, groups=groups) 115 | self.norm = nn.BatchNorm2d(out_features) 116 
| self.relu = nn.ReLU() 117 | def forward(self, x): 118 | out = self.conv(x) 119 | out = self.norm(out) 120 | out = self.relu(out) 121 | return out 122 | 123 | 124 | class FaceEncoder(nn.Module): 125 | def __init__(self, in_channel, out_dim): 126 | super(FaceEncoder, self).__init__() 127 | self.in_channel = in_channel 128 | self.out_dim = out_dim 129 | self.face_conv = nn.Sequential( 130 | SameBlock2d(in_channel,64,kernel_size=7,padding=3), 131 | # # 64 → 32 132 | ResBlock2d(64, 64, kernel_size=3, padding=1), 133 | DownBlock2d(64,64,3,1), 134 | SameBlock2d(64, 128), 135 | # 32 → 16 136 | ResBlock2d(128, 128, kernel_size=3, padding=1), 137 | DownBlock2d(128,128,3,1), 138 | SameBlock2d(128, 128), 139 | # 16 → 8 140 | ResBlock2d(128, 128, kernel_size=3, padding=1), 141 | DownBlock2d(128,128,3,1), 142 | SameBlock2d(128, 128), 143 | # 8 → 4 144 | ResBlock2d(128, 128, kernel_size=3, padding=1), 145 | DownBlock2d(128,128,3,1), 146 | SameBlock2d(128, 128), 147 | # 4 → 2 148 | ResBlock2d(128, 128, kernel_size=3, padding=1), 149 | DownBlock2d(128,128,3,1), 150 | SameBlock2d(128,out_dim,kernel_size=1,padding=0) 151 | ) 152 | def forward(self, x): 153 | ## b x c x h x w 154 | out = self.face_conv(x) 155 | return out 156 | 157 | class AudioEncoder(nn.Module): 158 | def __init__(self, in_channel, out_dim): 159 | super(AudioEncoder, self).__init__() 160 | self.in_channel = in_channel 161 | self.out_dim = out_dim 162 | self.audio_conv = nn.Sequential( 163 | SameBlock1d(in_channel,128,kernel_size=7,padding=3), 164 | ResBlock1d(128, 128, 3, 1), 165 | # 9-5 166 | DownBlock1d(128, 128, 3, 1), 167 | ResBlock1d(128, 128, 3, 1), 168 | # 5 -3 169 | DownBlock1d(128, 128, 3, 1), 170 | ResBlock1d(128, 128, 3, 1), 171 | # 3-2 172 | DownBlock1d(128, 128, 3, 1), 173 | SameBlock1d(128,out_dim,kernel_size=3,padding=1) 174 | ) 175 | self.global_avg = nn.AdaptiveAvgPool1d(1) 176 | def forward(self, x): 177 | ## b x c x t 178 | out = self.audio_conv(x) 179 | return self.global_avg(out).squeeze(2) 180 | 181 | class SyncNet(nn.Module): 182 | def __init__(self, in_channel_image,in_channel_audio, out_dim): 183 | super(SyncNet, self).__init__() 184 | self.in_channel_image = in_channel_image 185 | self.in_channel_audio = in_channel_audio 186 | self.out_dim = out_dim 187 | self.face_encoder = FaceEncoder(in_channel_image,out_dim) 188 | self.audio_encoder = AudioEncoder(in_channel_audio,out_dim) 189 | self.merge_encoder = nn.Sequential( 190 | nn.Conv2d(out_dim * 2, out_dim, kernel_size=3, padding=1), 191 | nn.LeakyReLU(0.2), 192 | nn.Conv2d(out_dim, 1, kernel_size=3, padding=1), 193 | ) 194 | def forward(self, image,audio): 195 | image_embedding = self.face_encoder(image) 196 | audio_embedding = self.audio_encoder(audio).unsqueeze(2).unsqueeze(3).repeat(1,1,image_embedding.size(2),image_embedding.size(3)) 197 | concat_embedding = torch.cat([image_embedding,audio_embedding],1) 198 | out_score = self.merge_encoder(concat_embedding) 199 | return out_score 200 | 201 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | opencv_python == 4.6.0.66 2 | numpy == 1.16.6 3 | python_speech_features == 0.6 4 | resampy == 0.2.2 5 | scipy == 1.5.4 6 | tensorflow == 1.15.2 7 | torch == 1.7.1+cu101 8 | torchvision == 0.8.2+cu101 9 | -------------------------------------------------------------------------------- /sync_batchnorm/__init__.py: 
-------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : __init__.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | from .batchnorm import set_sbn_eps_mode 12 | from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d 13 | from .batchnorm import patch_sync_batchnorm, convert_model 14 | from .replicate import DataParallelWithCallback, patch_replication_callback 15 | -------------------------------------------------------------------------------- /sync_batchnorm/batchnorm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : batchnorm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import collections 12 | import contextlib 13 | 14 | import torch 15 | import torch.nn.functional as F 16 | 17 | from torch.nn.modules.batchnorm import _BatchNorm 18 | 19 | try: 20 | from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast 21 | except ImportError: 22 | ReduceAddCoalesced = Broadcast = None 23 | 24 | try: 25 | from jactorch.parallel.comm import SyncMaster 26 | from jactorch.parallel.data_parallel import JacDataParallel as DataParallelWithCallback 27 | except ImportError: 28 | from .comm import SyncMaster 29 | from .replicate import DataParallelWithCallback 30 | 31 | __all__ = [ 32 | 'set_sbn_eps_mode', 33 | 'SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d', 34 | 'patch_sync_batchnorm', 'convert_model' 35 | ] 36 | 37 | 38 | SBN_EPS_MODE = 'clamp' 39 | 40 | 41 | def set_sbn_eps_mode(mode): 42 | global SBN_EPS_MODE 43 | assert mode in ('clamp', 'plus') 44 | SBN_EPS_MODE = mode 45 | 46 | 47 | def _sum_ft(tensor): 48 | """sum over the first and last dimention""" 49 | return tensor.sum(dim=0).sum(dim=-1) 50 | 51 | 52 | def _unsqueeze_ft(tensor): 53 | """add new dimensions at the front and the tail""" 54 | return tensor.unsqueeze(0).unsqueeze(-1) 55 | 56 | 57 | _ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size']) 58 | _MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std']) 59 | 60 | 61 | class _SynchronizedBatchNorm(_BatchNorm): 62 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True): 63 | assert ReduceAddCoalesced is not None, 'Can not use Synchronized Batch Normalization without CUDA support.' 64 | 65 | super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine, 66 | track_running_stats=track_running_stats) 67 | 68 | if not self.track_running_stats: 69 | import warnings 70 | warnings.warn('track_running_stats=False is not supported by the SynchronizedBatchNorm.') 71 | 72 | self._sync_master = SyncMaster(self._data_parallel_master) 73 | 74 | self._is_parallel = False 75 | self._parallel_id = None 76 | self._slave_pipe = None 77 | 78 | def forward(self, input): 79 | # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation. 
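# With a single device or in eval mode there is nothing to synchronize, so the
# layer delegates to F.batch_norm with its own running statistics and behaves
# like the built-in BatchNorm layers.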
80 | if not (self._is_parallel and self.training): 81 | return F.batch_norm( 82 | input, self.running_mean, self.running_var, self.weight, self.bias, 83 | self.training, self.momentum, self.eps) 84 | 85 | # Resize the input to (B, C, -1). 86 | input_shape = input.size() 87 | assert input.size(1) == self.num_features, 'Channel size mismatch: got {}, expect {}.'.format(input.size(1), self.num_features) 88 | input = input.view(input.size(0), self.num_features, -1) 89 | 90 | # Compute the sum and square-sum. 91 | sum_size = input.size(0) * input.size(2) 92 | input_sum = _sum_ft(input) 93 | input_ssum = _sum_ft(input ** 2) 94 | 95 | # Reduce-and-broadcast the statistics. 96 | if self._parallel_id == 0: 97 | mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size)) 98 | else: 99 | mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size)) 100 | 101 | # Compute the output. 102 | if self.affine: 103 | # MJY:: Fuse the multiplication for speed. 104 | output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias) 105 | else: 106 | output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std) 107 | 108 | # Reshape it. 109 | return output.view(input_shape) 110 | 111 | def __data_parallel_replicate__(self, ctx, copy_id): 112 | self._is_parallel = True 113 | self._parallel_id = copy_id 114 | 115 | # parallel_id == 0 means master device. 116 | if self._parallel_id == 0: 117 | ctx.sync_master = self._sync_master 118 | else: 119 | self._slave_pipe = ctx.sync_master.register_slave(copy_id) 120 | 121 | def _data_parallel_master(self, intermediates): 122 | """Reduce the sum and square-sum, compute the statistics, and broadcast it.""" 123 | 124 | # Always using same "device order" makes the ReduceAdd operation faster. 125 | # Thanks to:: Tete Xiao (http://tetexiao.com/) 126 | intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device()) 127 | 128 | to_reduce = [i[1][:2] for i in intermediates] 129 | to_reduce = [j for i in to_reduce for j in i] # flatten 130 | target_gpus = [i[1].sum.get_device() for i in intermediates] 131 | 132 | sum_size = sum([i[1].sum_size for i in intermediates]) 133 | sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce) 134 | mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size) 135 | 136 | broadcasted = Broadcast.apply(target_gpus, mean, inv_std) 137 | 138 | outputs = [] 139 | for i, rec in enumerate(intermediates): 140 | outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2]))) 141 | 142 | return outputs 143 | 144 | def _compute_mean_std(self, sum_, ssum, size): 145 | """Compute the mean and standard-deviation with sum and square-sum. This method 146 | also maintains the moving average on the master device.""" 147 | assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.' 
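# Recover the batch statistics from the reduced sums:
#   mean         = sum(x) / N
#   biased var   = sum(x^2) / N - mean^2      (used for normalization)
#   unbiased var = biased var * N / (N - 1)   (used to update running_var)
# The running estimates below are blended in with the configured momentum, and
# the method returns (mean, inv_std), where inv_std is computed from the biased
# variance with the eps handling selected by set_sbn_eps_mode.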
148 | mean = sum_ / size 149 | sumvar = ssum - sum_ * mean 150 | unbias_var = sumvar / (size - 1) 151 | bias_var = sumvar / size 152 | 153 | if hasattr(torch, 'no_grad'): 154 | with torch.no_grad(): 155 | self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data 156 | self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data 157 | else: 158 | self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data 159 | self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data 160 | 161 | if SBN_EPS_MODE == 'clamp': 162 | return mean, bias_var.clamp(self.eps) ** -0.5 163 | elif SBN_EPS_MODE == 'plus': 164 | return mean, (bias_var + self.eps) ** -0.5 165 | else: 166 | raise ValueError('Unknown EPS mode: {}.'.format(SBN_EPS_MODE)) 167 | 168 | 169 | class SynchronizedBatchNorm1d(_SynchronizedBatchNorm): 170 | r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a 171 | mini-batch. 172 | 173 | .. math:: 174 | 175 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta 176 | 177 | This module differs from the built-in PyTorch BatchNorm1d as the mean and 178 | standard-deviation are reduced across all devices during training. 179 | 180 | For example, when one uses `nn.DataParallel` to wrap the network during 181 | training, PyTorch's implementation normalize the tensor on each device using 182 | the statistics only on that device, which accelerated the computation and 183 | is also easy to implement, but the statistics might be inaccurate. 184 | Instead, in this synchronized version, the statistics will be computed 185 | over all training samples distributed on multiple devices. 186 | 187 | Note that, for one-GPU or CPU-only case, this module behaves exactly same 188 | as the built-in PyTorch implementation. 189 | 190 | The mean and standard-deviation are calculated per-dimension over 191 | the mini-batches and gamma and beta are learnable parameter vectors 192 | of size C (where C is the input size). 193 | 194 | During training, this layer keeps a running estimate of its computed mean 195 | and variance. The running sum is kept with a default momentum of 0.1. 196 | 197 | During evaluation, this running mean/variance is used for normalization. 198 | 199 | Because the BatchNorm is done over the `C` dimension, computing statistics 200 | on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm 201 | 202 | Args: 203 | num_features: num_features from an expected input of size 204 | `batch_size x num_features [x width]` 205 | eps: a value added to the denominator for numerical stability. 206 | Default: 1e-5 207 | momentum: the value used for the running_mean and running_var 208 | computation. Default: 0.1 209 | affine: a boolean value that when set to ``True``, gives the layer learnable 210 | affine parameters. 
Default: ``True`` 211 | 212 | Shape:: 213 | - Input: :math:`(N, C)` or :math:`(N, C, L)` 214 | - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input) 215 | 216 | Examples: 217 | >>> # With Learnable Parameters 218 | >>> m = SynchronizedBatchNorm1d(100) 219 | >>> # Without Learnable Parameters 220 | >>> m = SynchronizedBatchNorm1d(100, affine=False) 221 | >>> input = torch.autograd.Variable(torch.randn(20, 100)) 222 | >>> output = m(input) 223 | """ 224 | 225 | def _check_input_dim(self, input): 226 | if input.dim() != 2 and input.dim() != 3: 227 | raise ValueError('expected 2D or 3D input (got {}D input)' 228 | .format(input.dim())) 229 | 230 | 231 | class SynchronizedBatchNorm2d(_SynchronizedBatchNorm): 232 | r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch 233 | of 3d inputs 234 | 235 | .. math:: 236 | 237 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta 238 | 239 | This module differs from the built-in PyTorch BatchNorm2d as the mean and 240 | standard-deviation are reduced across all devices during training. 241 | 242 | For example, when one uses `nn.DataParallel` to wrap the network during 243 | training, PyTorch's implementation normalize the tensor on each device using 244 | the statistics only on that device, which accelerated the computation and 245 | is also easy to implement, but the statistics might be inaccurate. 246 | Instead, in this synchronized version, the statistics will be computed 247 | over all training samples distributed on multiple devices. 248 | 249 | Note that, for one-GPU or CPU-only case, this module behaves exactly same 250 | as the built-in PyTorch implementation. 251 | 252 | The mean and standard-deviation are calculated per-dimension over 253 | the mini-batches and gamma and beta are learnable parameter vectors 254 | of size C (where C is the input size). 255 | 256 | During training, this layer keeps a running estimate of its computed mean 257 | and variance. The running sum is kept with a default momentum of 0.1. 258 | 259 | During evaluation, this running mean/variance is used for normalization. 260 | 261 | Because the BatchNorm is done over the `C` dimension, computing statistics 262 | on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm 263 | 264 | Args: 265 | num_features: num_features from an expected input of 266 | size batch_size x num_features x height x width 267 | eps: a value added to the denominator for numerical stability. 268 | Default: 1e-5 269 | momentum: the value used for the running_mean and running_var 270 | computation. Default: 0.1 271 | affine: a boolean value that when set to ``True``, gives the layer learnable 272 | affine parameters. Default: ``True`` 273 | 274 | Shape:: 275 | - Input: :math:`(N, C, H, W)` 276 | - Output: :math:`(N, C, H, W)` (same shape as input) 277 | 278 | Examples: 279 | >>> # With Learnable Parameters 280 | >>> m = SynchronizedBatchNorm2d(100) 281 | >>> # Without Learnable Parameters 282 | >>> m = SynchronizedBatchNorm2d(100, affine=False) 283 | >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45)) 284 | >>> output = m(input) 285 | """ 286 | 287 | def _check_input_dim(self, input): 288 | if input.dim() != 4: 289 | raise ValueError('expected 4D input (got {}D input)' 290 | .format(input.dim())) 291 | 292 | 293 | class SynchronizedBatchNorm3d(_SynchronizedBatchNorm): 294 | r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch 295 | of 4d inputs 296 | 297 | .. 
math:: 298 | 299 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta 300 | 301 | This module differs from the built-in PyTorch BatchNorm3d as the mean and 302 | standard-deviation are reduced across all devices during training. 303 | 304 | For example, when one uses `nn.DataParallel` to wrap the network during 305 | training, PyTorch's implementation normalize the tensor on each device using 306 | the statistics only on that device, which accelerated the computation and 307 | is also easy to implement, but the statistics might be inaccurate. 308 | Instead, in this synchronized version, the statistics will be computed 309 | over all training samples distributed on multiple devices. 310 | 311 | Note that, for one-GPU or CPU-only case, this module behaves exactly same 312 | as the built-in PyTorch implementation. 313 | 314 | The mean and standard-deviation are calculated per-dimension over 315 | the mini-batches and gamma and beta are learnable parameter vectors 316 | of size C (where C is the input size). 317 | 318 | During training, this layer keeps a running estimate of its computed mean 319 | and variance. The running sum is kept with a default momentum of 0.1. 320 | 321 | During evaluation, this running mean/variance is used for normalization. 322 | 323 | Because the BatchNorm is done over the `C` dimension, computing statistics 324 | on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm 325 | or Spatio-temporal BatchNorm 326 | 327 | Args: 328 | num_features: num_features from an expected input of 329 | size batch_size x num_features x depth x height x width 330 | eps: a value added to the denominator for numerical stability. 331 | Default: 1e-5 332 | momentum: the value used for the running_mean and running_var 333 | computation. Default: 0.1 334 | affine: a boolean value that when set to ``True``, gives the layer learnable 335 | affine parameters. 
Default: ``True`` 336 | 337 | Shape:: 338 | - Input: :math:`(N, C, D, H, W)` 339 | - Output: :math:`(N, C, D, H, W)` (same shape as input) 340 | 341 | Examples: 342 | >>> # With Learnable Parameters 343 | >>> m = SynchronizedBatchNorm3d(100) 344 | >>> # Without Learnable Parameters 345 | >>> m = SynchronizedBatchNorm3d(100, affine=False) 346 | >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10)) 347 | >>> output = m(input) 348 | """ 349 | 350 | def _check_input_dim(self, input): 351 | if input.dim() != 5: 352 | raise ValueError('expected 5D input (got {}D input)' 353 | .format(input.dim())) 354 | 355 | 356 | @contextlib.contextmanager 357 | def patch_sync_batchnorm(): 358 | import torch.nn as nn 359 | 360 | backup = nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d 361 | 362 | nn.BatchNorm1d = SynchronizedBatchNorm1d 363 | nn.BatchNorm2d = SynchronizedBatchNorm2d 364 | nn.BatchNorm3d = SynchronizedBatchNorm3d 365 | 366 | yield 367 | 368 | nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d = backup 369 | 370 | 371 | def convert_model(module): 372 | """Traverse the input module and its child recursively 373 | and replace all instance of torch.nn.modules.batchnorm.BatchNorm*N*d 374 | to SynchronizedBatchNorm*N*d 375 | 376 | Args: 377 | module: the input module needs to be convert to SyncBN model 378 | 379 | Examples: 380 | >>> import torch.nn as nn 381 | >>> import torchvision 382 | >>> # m is a standard pytorch model 383 | >>> m = torchvision.models.resnet18(True) 384 | >>> m = nn.DataParallel(m) 385 | >>> # after convert, m is using SyncBN 386 | >>> m = convert_model(m) 387 | """ 388 | if isinstance(module, torch.nn.DataParallel): 389 | mod = module.module 390 | mod = convert_model(mod) 391 | mod = DataParallelWithCallback(mod, device_ids=module.device_ids) 392 | return mod 393 | 394 | mod = module 395 | for pth_module, sync_module in zip([torch.nn.modules.batchnorm.BatchNorm1d, 396 | torch.nn.modules.batchnorm.BatchNorm2d, 397 | torch.nn.modules.batchnorm.BatchNorm3d], 398 | [SynchronizedBatchNorm1d, 399 | SynchronizedBatchNorm2d, 400 | SynchronizedBatchNorm3d]): 401 | if isinstance(module, pth_module): 402 | mod = sync_module(module.num_features, module.eps, module.momentum, module.affine) 403 | mod.running_mean = module.running_mean 404 | mod.running_var = module.running_var 405 | if module.affine: 406 | mod.weight.data = module.weight.data.clone().detach() 407 | mod.bias.data = module.bias.data.clone().detach() 408 | 409 | for name, child in module.named_children(): 410 | mod.add_module(name, convert_model(child)) 411 | 412 | return mod 413 | -------------------------------------------------------------------------------- /sync_batchnorm/batchnorm_reimpl.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # File : batchnorm_reimpl.py 4 | # Author : acgtyrant 5 | # Date : 11/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.init as init 14 | 15 | __all__ = ['BatchNorm2dReimpl'] 16 | 17 | 18 | class BatchNorm2dReimpl(nn.Module): 19 | """ 20 | A re-implementation of batch normalization, used for testing the numerical 21 | stability. 
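It mirrors the arithmetic of _SynchronizedBatchNorm: the biased variance is
used to normalize the batch while the unbiased variance updates running_var,
so the two implementations can be compared numerically.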
22 | 23 | Author: acgtyrant 24 | See also: 25 | https://github.com/vacancy/Synchronized-BatchNorm-PyTorch/issues/14 26 | """ 27 | def __init__(self, num_features, eps=1e-5, momentum=0.1): 28 | super().__init__() 29 | 30 | self.num_features = num_features 31 | self.eps = eps 32 | self.momentum = momentum 33 | self.weight = nn.Parameter(torch.empty(num_features)) 34 | self.bias = nn.Parameter(torch.empty(num_features)) 35 | self.register_buffer('running_mean', torch.zeros(num_features)) 36 | self.register_buffer('running_var', torch.ones(num_features)) 37 | self.reset_parameters() 38 | 39 | def reset_running_stats(self): 40 | self.running_mean.zero_() 41 | self.running_var.fill_(1) 42 | 43 | def reset_parameters(self): 44 | self.reset_running_stats() 45 | init.uniform_(self.weight) 46 | init.zeros_(self.bias) 47 | 48 | def forward(self, input_): 49 | batchsize, channels, height, width = input_.size() 50 | numel = batchsize * height * width 51 | input_ = input_.permute(1, 0, 2, 3).contiguous().view(channels, numel) 52 | sum_ = input_.sum(1) 53 | sum_of_square = input_.pow(2).sum(1) 54 | mean = sum_ / numel 55 | sumvar = sum_of_square - sum_ * mean 56 | 57 | self.running_mean = ( 58 | (1 - self.momentum) * self.running_mean 59 | + self.momentum * mean.detach() 60 | ) 61 | unbias_var = sumvar / (numel - 1) 62 | self.running_var = ( 63 | (1 - self.momentum) * self.running_var 64 | + self.momentum * unbias_var.detach() 65 | ) 66 | 67 | bias_var = sumvar / numel 68 | inv_std = 1 / (bias_var + self.eps).pow(0.5) 69 | output = ( 70 | (input_ - mean.unsqueeze(1)) * inv_std.unsqueeze(1) * 71 | self.weight.unsqueeze(1) + self.bias.unsqueeze(1)) 72 | 73 | return output.view(channels, batchsize, height, width).permute(1, 0, 2, 3).contiguous() 74 | 75 | -------------------------------------------------------------------------------- /sync_batchnorm/comm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : comm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import queue 12 | import collections 13 | import threading 14 | 15 | __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster'] 16 | 17 | 18 | class FutureResult(object): 19 | """A thread-safe future implementation. Used only as one-to-one pipe.""" 20 | 21 | def __init__(self): 22 | self._result = None 23 | self._lock = threading.Lock() 24 | self._cond = threading.Condition(self._lock) 25 | 26 | def put(self, result): 27 | with self._lock: 28 | assert self._result is None, 'Previous result has\'t been fetched.' 29 | self._result = result 30 | self._cond.notify() 31 | 32 | def get(self): 33 | with self._lock: 34 | if self._result is None: 35 | self._cond.wait() 36 | 37 | res = self._result 38 | self._result = None 39 | return res 40 | 41 | 42 | _MasterRegistry = collections.namedtuple('MasterRegistry', ['result']) 43 | _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result']) 44 | 45 | 46 | class SlavePipe(_SlavePipeBase): 47 | """Pipe for master-slave communication.""" 48 | 49 | def run_slave(self, msg): 50 | self.queue.put((self.identifier, msg)) 51 | ret = self.result.get() 52 | self.queue.put(True) 53 | return ret 54 | 55 | 56 | class SyncMaster(object): 57 | """An abstract `SyncMaster` object. 
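In this code base it is used by the synchronized BatchNorm layers: the master
copy collects the per-device batch statistics, computes mean / inv_std once,
and broadcasts the result back to every replica.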
58 | 59 | - During the replication, as the data parallel will trigger an callback of each module, all slave devices should 60 | call `register(id)` and obtain an `SlavePipe` to communicate with the master. 61 | - During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected, 62 | and passed to a registered callback. 63 | - After receiving the messages, the master device should gather the information and determine to message passed 64 | back to each slave devices. 65 | """ 66 | 67 | def __init__(self, master_callback): 68 | """ 69 | 70 | Args: 71 | master_callback: a callback to be invoked after having collected messages from slave devices. 72 | """ 73 | self._master_callback = master_callback 74 | self._queue = queue.Queue() 75 | self._registry = collections.OrderedDict() 76 | self._activated = False 77 | 78 | def __getstate__(self): 79 | return {'master_callback': self._master_callback} 80 | 81 | def __setstate__(self, state): 82 | self.__init__(state['master_callback']) 83 | 84 | def register_slave(self, identifier): 85 | """ 86 | Register an slave device. 87 | 88 | Args: 89 | identifier: an identifier, usually is the device id. 90 | 91 | Returns: a `SlavePipe` object which can be used to communicate with the master device. 92 | 93 | """ 94 | if self._activated: 95 | assert self._queue.empty(), 'Queue is not clean before next initialization.' 96 | self._activated = False 97 | self._registry.clear() 98 | future = FutureResult() 99 | self._registry[identifier] = _MasterRegistry(future) 100 | return SlavePipe(identifier, self._queue, future) 101 | 102 | def run_master(self, master_msg): 103 | """ 104 | Main entry for the master device in each forward pass. 105 | The messages were first collected from each devices (including the master device), and then 106 | an callback will be invoked to compute the message to be sent back to each devices 107 | (including the master device). 108 | 109 | Args: 110 | master_msg: the message that the master want to send to itself. This will be placed as the first 111 | message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example. 112 | 113 | Returns: the message to be sent back to the master device. 114 | 115 | """ 116 | self._activated = True 117 | 118 | intermediates = [(0, master_msg)] 119 | for i in range(self.nr_slaves): 120 | intermediates.append(self._queue.get()) 121 | 122 | results = self._master_callback(intermediates) 123 | assert results[0][0] == 0, 'The first result should belongs to the master.' 124 | 125 | for i, res in results: 126 | if i == 0: 127 | continue 128 | self._registry[i].result.put(res) 129 | 130 | for i in range(self.nr_slaves): 131 | assert self._queue.get() is True 132 | 133 | return results[0][1] 134 | 135 | @property 136 | def nr_slaves(self): 137 | return len(self._registry) 138 | -------------------------------------------------------------------------------- /sync_batchnorm/replicate.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : replicate.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 
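# This module adds a replication callback to nn.DataParallel: after the network
# is copied to each device, every sub-module that defines
# __data_parallel_replicate__ is invoked with a shared CallbackContext and its
# copy id, which is how the synchronized BatchNorm layers register their
# master / slave communication pipes.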
10 | 11 | import functools 12 | 13 | from torch.nn.parallel.data_parallel import DataParallel 14 | 15 | __all__ = [ 16 | 'CallbackContext', 17 | 'execute_replication_callbacks', 18 | 'DataParallelWithCallback', 19 | 'patch_replication_callback' 20 | ] 21 | 22 | 23 | class CallbackContext(object): 24 | pass 25 | 26 | 27 | def execute_replication_callbacks(modules): 28 | """ 29 | Execute an replication callback `__data_parallel_replicate__` on each module created by original replication. 30 | 31 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 32 | 33 | Note that, as all modules are isomorphism, we assign each sub-module with a context 34 | (shared among multiple copies of this module on different devices). 35 | Through this context, different copies can share some information. 36 | 37 | We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback 38 | of any slave copies. 39 | """ 40 | master_copy = modules[0] 41 | nr_modules = len(list(master_copy.modules())) 42 | ctxs = [CallbackContext() for _ in range(nr_modules)] 43 | 44 | for i, module in enumerate(modules): 45 | for j, m in enumerate(module.modules()): 46 | if hasattr(m, '__data_parallel_replicate__'): 47 | m.__data_parallel_replicate__(ctxs[j], i) 48 | 49 | 50 | class DataParallelWithCallback(DataParallel): 51 | """ 52 | Data Parallel with a replication callback. 53 | 54 | An replication callback `__data_parallel_replicate__` of each module will be invoked after being created by 55 | original `replicate` function. 56 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 57 | 58 | Examples: 59 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 60 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 61 | # sync_bn.__data_parallel_replicate__ will be invoked. 62 | """ 63 | 64 | def replicate(self, module, device_ids): 65 | modules = super(DataParallelWithCallback, self).replicate(module, device_ids) 66 | execute_replication_callbacks(modules) 67 | return modules 68 | 69 | 70 | def patch_replication_callback(data_parallel): 71 | """ 72 | Monkey-patch an existing `DataParallel` object. Add the replication callback. 73 | Useful when you have customized `DataParallel` implementation. 74 | 75 | Examples: 76 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 77 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) 78 | > patch_replication_callback(sync_bn) 79 | # this is equivalent to 80 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 81 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 82 | """ 83 | 84 | assert isinstance(data_parallel, DataParallel) 85 | 86 | old_replicate = data_parallel.replicate 87 | 88 | @functools.wraps(old_replicate) 89 | def new_replicate(module, device_ids): 90 | modules = old_replicate(module, device_ids) 91 | execute_replication_callbacks(modules) 92 | return modules 93 | 94 | data_parallel.replicate = new_replicate 95 | -------------------------------------------------------------------------------- /sync_batchnorm/unittest.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : unittest.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 
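# Provides TorchTestCase, a small unittest.TestCase extension whose
# assertTensorClose() compares two tensors with torch.allclose
# (atol=1e-5, rtol=1e-3) and reports the absolute / relative difference
# when the check fails.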
10 | 11 | import unittest 12 | import torch 13 | 14 | 15 | class TorchTestCase(unittest.TestCase): 16 | def assertTensorClose(self, x, y): 17 | adiff = float((x - y).abs().max()) 18 | if (y == 0).all(): 19 | rdiff = 'NaN' 20 | else: 21 | rdiff = float((adiff / y).abs().max()) 22 | 23 | message = ( 24 | 'Tensor close check failed\n' 25 | 'adiff={}\n' 26 | 'rdiff={}\n' 27 | ).format(adiff, rdiff) 28 | self.assertTrue(torch.allclose(x, y, atol=1e-5, rtol=1e-3), message) 29 | 30 | -------------------------------------------------------------------------------- /train_DINet_clip.py: -------------------------------------------------------------------------------- 1 | from models.Discriminator import Discriminator 2 | from models.VGG19 import Vgg19 3 | from models.DINet import DINet 4 | from models.Syncnet import SyncNetPerception 5 | from utils.training_utils import get_scheduler, update_learning_rate,GANLoss 6 | from config.config import DINetTrainingOptions 7 | from sync_batchnorm import convert_model 8 | from torch.utils.data import DataLoader 9 | from dataset.dataset_DINet_clip import DINetDataset 10 | 11 | 12 | import random 13 | import numpy as np 14 | import torch 15 | import torch.nn as nn 16 | import torch.optim as optim 17 | import os 18 | import torch.nn.functional as F 19 | 20 | if __name__ == "__main__": 21 | ''' 22 | clip training code of DINet 23 | in the resolution you want, using clip training code after frame training 24 | 25 | ''' 26 | # load config 27 | opt = DINetTrainingOptions().parse_args() 28 | random.seed(opt.seed) 29 | np.random.seed(opt.seed) 30 | torch.cuda.manual_seed(opt.seed) 31 | # load training data 32 | train_data = DINetDataset(opt.train_data,opt.augment_num,opt.mouth_region_size) 33 | training_data_loader = DataLoader(dataset=train_data, batch_size=opt.batch_size, shuffle=True,drop_last=True) 34 | train_data_length = len(training_data_loader) 35 | # init network 36 | net_g = DINet(opt.source_channel,opt.ref_channel,opt.audio_channel).cuda() 37 | net_dI = Discriminator(opt.source_channel ,opt.D_block_expansion, opt.D_num_blocks, opt.D_max_features).cuda() 38 | net_dV = Discriminator(opt.source_channel * 5, opt.D_block_expansion, opt.D_num_blocks, opt.D_max_features).cuda() 39 | net_vgg = Vgg19().cuda() 40 | net_lipsync = SyncNetPerception(opt.pretrained_syncnet_path).cuda() 41 | # parallel 42 | net_g = nn.DataParallel(net_g) 43 | net_g = convert_model(net_g) 44 | net_dI = nn.DataParallel(net_dI) 45 | net_dV = nn.DataParallel(net_dV) 46 | net_vgg = nn.DataParallel(net_vgg) 47 | # setup optimizer 48 | optimizer_g = optim.Adam(net_g.parameters(), lr=opt.lr_g) 49 | optimizer_dI = optim.Adam(net_dI.parameters(), lr=opt.lr_dI) 50 | optimizer_dV = optim.Adam(net_dV.parameters(), lr=opt.lr_dI) 51 | ## load frame trained DInet weight 52 | print('loading frame trained DINet weight from: {}'.format(opt.pretrained_frame_DINet_path)) 53 | checkpoint = torch.load(opt.pretrained_frame_DINet_path) 54 | net_g.load_state_dict(checkpoint['state_dict']['net_g']) 55 | # set criterion 56 | criterionGAN = GANLoss().cuda() 57 | criterionL1 = nn.L1Loss().cuda() 58 | criterionMSE = nn.MSELoss().cuda() 59 | # set scheduler 60 | net_g_scheduler = get_scheduler(optimizer_g, opt.non_decay, opt.decay) 61 | net_dI_scheduler = get_scheduler(optimizer_dI, opt.non_decay, opt.decay) 62 | net_dV_scheduler = get_scheduler(optimizer_dV, opt.non_decay, opt.decay) 63 | # set label of syncnet perception loss 64 | real_tensor = torch.tensor(1.0).cuda() 65 | # start train 66 | for epoch in 
range(opt.start_epoch, opt.non_decay+opt.decay+1):
67 |         net_g.train()
68 |         for iteration, data in enumerate(training_data_loader):
69 |             # forward: fold the clip frames into the batch dimension and run the generator
70 |             source_clip, source_clip_mask, reference_clip, deep_speech_clip, deep_speech_full = data
71 |             source_clip = torch.cat(torch.split(source_clip, 1, dim=1), 0).squeeze(1).float().cuda()
72 |             source_clip_mask = torch.cat(torch.split(source_clip_mask, 1, dim=1), 0).squeeze(1).float().cuda()
73 |             reference_clip = torch.cat(torch.split(reference_clip, 1, dim=1), 0).squeeze(1).float().cuda()
74 |             deep_speech_clip = torch.cat(torch.split(deep_speech_clip, 1, dim=1), 0).squeeze(1).float().cuda()
75 |             deep_speech_full = deep_speech_full.float().cuda()
76 |             fake_out = net_g(source_clip_mask, reference_clip, deep_speech_clip)
77 |             fake_out_half = F.avg_pool2d(fake_out, 3, 2, 1, count_include_pad=False)
78 |             source_clip_half = F.interpolate(source_clip, scale_factor=0.5, mode='bilinear')
79 |             # (1) Update DI network (frame discriminator)
80 |             optimizer_dI.zero_grad()
81 |             _, pred_fake_dI = net_dI(fake_out)
82 |             loss_dI_fake = criterionGAN(pred_fake_dI, False)
83 |             _, pred_real_dI = net_dI(source_clip)
84 |             loss_dI_real = criterionGAN(pred_real_dI, True)
85 |             # Combined DI loss
86 |             loss_dI = (loss_dI_fake + loss_dI_real) * 0.5
87 |             loss_dI.backward(retain_graph=True)
88 |             optimizer_dI.step()
89 |
90 |             # (2) Update DV network (clip discriminator)
91 |             optimizer_dV.zero_grad()
92 |             condition_fake_dV = torch.cat(torch.split(fake_out, opt.batch_size, dim=0), 1)
93 |             _, pred_fake_dV = net_dV(condition_fake_dV)
94 |             loss_dV_fake = criterionGAN(pred_fake_dV, False)
95 |             condition_real_dV = torch.cat(torch.split(source_clip, opt.batch_size, dim=0), 1)
96 |             _, pred_real_dV = net_dV(condition_real_dV)
97 |             loss_dV_real = criterionGAN(pred_real_dV, True)
98 |             # Combined DV loss
99 |             loss_dV = (loss_dV_fake + loss_dV_real) * 0.5
100 |             loss_dV.backward(retain_graph=True)
101 |             optimizer_dV.step()
102 |
103 |             # (3) Update DINet (generator)
104 |             _, pred_fake_dI = net_dI(fake_out)
105 |             _, pred_fake_dV = net_dV(condition_fake_dV)
106 |             optimizer_g.zero_grad()
107 |             # compute perception loss
108 |             perception_real = net_vgg(source_clip)
109 |             perception_fake = net_vgg(fake_out)
110 |             perception_real_half = net_vgg(source_clip_half)
111 |             perception_fake_half = net_vgg(fake_out_half)
112 |             loss_g_perception = 0
113 |             for i in range(len(perception_real)):
114 |                 loss_g_perception += criterionL1(perception_fake[i], perception_real[i])
115 |                 loss_g_perception += criterionL1(perception_fake_half[i], perception_real_half[i])
116 |             loss_g_perception = (loss_g_perception / (len(perception_real) * 2)) * opt.lamb_perception
117 |             # gan dI loss
118 |             loss_g_dI = criterionGAN(pred_fake_dI, True)
119 |             # gan dV loss
120 |             loss_g_dV = criterionGAN(pred_fake_dV, True)
121 |             # sync perception loss (SyncNet score on the generated mouth region)
122 |             fake_out_clip = torch.cat(torch.split(fake_out, opt.batch_size, dim=0), 1)
123 |             fake_out_clip_mouth = fake_out_clip[:, :, train_data.radius:train_data.radius + train_data.mouth_region_size,
124 |                                                 train_data.radius_1_4:train_data.radius_1_4 + train_data.mouth_region_size]
125 |             sync_score = net_lipsync(fake_out_clip_mouth, deep_speech_full)
126 |             loss_sync = criterionMSE(sync_score, real_tensor.expand_as(sync_score)) * opt.lamb_syncnet_perception
127 |             # combine all losses
128 |             loss_g = loss_g_perception + loss_g_dI + loss_g_dV + loss_sync
129 |             loss_g.backward()
130 |             optimizer_g.step()
131 |
132 |             print(
133 |                 "===> Epoch[{}]({}/{}): Loss_DI: {:.4f} Loss_GI: {:.4f} Loss_DV: {:.4f} Loss_GV: {:.4f} Loss_perception: {:.4f} Loss_sync: {:.4f} lr_g = {:.7f} ".format(
134 |                     epoch, iteration, len(training_data_loader), float(loss_dI), float(loss_g_dI), float(loss_dV), float(loss_g_dV), float(loss_g_perception), float(loss_sync),
135 |                     optimizer_g.param_groups[0]['lr']))
136 |
137 |         update_learning_rate(net_g_scheduler, optimizer_g)
138 |         update_learning_rate(net_dI_scheduler, optimizer_dI)
139 |         update_learning_rate(net_dV_scheduler, optimizer_dV)
140 |         # checkpoint
141 |         if epoch % opt.checkpoint == 0:
142 |             if not os.path.exists(opt.result_path):
143 |                 os.mkdir(opt.result_path)
144 |             model_out_path = os.path.join(opt.result_path, 'netG_model_epoch_{}.pth'.format(epoch))
145 |             states = {
146 |                 'epoch': epoch + 1,
147 |                 'state_dict': {'net_g': net_g.state_dict(), 'net_dI': net_dI.state_dict(), 'net_dV': net_dV.state_dict()},
148 |                 'optimizer': {'net_g': optimizer_g.state_dict(), 'net_dI': optimizer_dI.state_dict(), 'net_dV': optimizer_dV.state_dict()}
149 |             }
150 |             torch.save(states, model_out_path)
151 |             print("Checkpoint saved to {}".format(model_out_path))
152 |
--------------------------------------------------------------------------------
/train_DINet_frame.py:
--------------------------------------------------------------------------------
1 | from models.Discriminator import Discriminator
2 | from models.VGG19 import Vgg19
3 | from models.DINet import DINet
4 | from utils.training_utils import get_scheduler, update_learning_rate, GANLoss
5 | from torch.utils.data import DataLoader
6 | from dataset.dataset_DINet_frame import DINetDataset
7 | from sync_batchnorm import convert_model
8 | from config.config import DINetTrainingOptions
9 |
10 | import random
11 | import numpy as np
12 | import os
13 | import torch.nn.functional as F
14 | import torch
15 | import torch.nn as nn
16 | import torch.optim as optim
17 |
18 | if __name__ == "__main__":
19 |     '''
20 |     frame training code of DINet
21 |     we use a coarse-to-fine training strategy,
22 |     so you can use this code to train the model at arbitrary resolution
23 |     '''
24 |     # load config
25 |     opt = DINetTrainingOptions().parse_args()
26 |     # set seed
27 |     random.seed(opt.seed)
28 |     np.random.seed(opt.seed)
29 |     torch.cuda.manual_seed(opt.seed)
30 |     # load training data in memory
31 |     train_data = DINetDataset(opt.train_data, opt.augment_num, opt.mouth_region_size)
32 |     training_data_loader = DataLoader(dataset=train_data, batch_size=opt.batch_size, shuffle=True, drop_last=True)
33 |     train_data_length = len(training_data_loader)
34 |     # init network
35 |     net_g = DINet(opt.source_channel, opt.ref_channel, opt.audio_channel).cuda()
36 |     net_dI = Discriminator(opt.source_channel, opt.D_block_expansion, opt.D_num_blocks, opt.D_max_features).cuda()
37 |     net_vgg = Vgg19().cuda()
38 |     # parallel
39 |     net_g = nn.DataParallel(net_g)
40 |     net_g = convert_model(net_g)
41 |     net_dI = nn.DataParallel(net_dI)
42 |     net_vgg = nn.DataParallel(net_vgg)
43 |     # setup optimizer
44 |     optimizer_g = optim.Adam(net_g.parameters(), lr=opt.lr_g)
45 |     optimizer_dI = optim.Adam(net_dI.parameters(), lr=opt.lr_dI)
46 |     # coarse2fine
47 |     if opt.coarse2fine:
48 |         print('loading checkpoint for coarse2fine training: {}'.format(opt.coarse_model_path))
49 |         checkpoint = torch.load(opt.coarse_model_path)
50 |         net_g.load_state_dict(checkpoint['state_dict']['net_g'])
51 |     # set criterion
52 |     criterionGAN = GANLoss().cuda()
53 |     criterionL1 = nn.L1Loss().cuda()
54 |     # set scheduler
55 |     net_g_scheduler = get_scheduler(optimizer_g, opt.non_decay, opt.decay)
56 |     net_dI_scheduler = get_scheduler(optimizer_dI, opt.non_decay, opt.decay)
57 |     # start train
58 |     for epoch in range(opt.start_epoch, opt.non_decay+opt.decay+1):
59 |         net_g.train()
60 |         for iteration, data in enumerate(training_data_loader):
61 |             # read data
62 |             source_image_data, source_image_mask, reference_clip_data, deepspeech_feature = data
63 |             source_image_data = source_image_data.float().cuda()
64 |             source_image_mask = source_image_mask.float().cuda()
65 |             reference_clip_data = reference_clip_data.float().cuda()
66 |             deepspeech_feature = deepspeech_feature.float().cuda()
67 |             # network forward
68 |             fake_out = net_g(source_image_mask, reference_clip_data, deepspeech_feature)
69 |             # down sample output image and real image
70 |             fake_out_half = F.avg_pool2d(fake_out, 3, 2, 1, count_include_pad=False)
71 |             target_tensor_half = F.interpolate(source_image_data, scale_factor=0.5, mode='bilinear')
72 |             # (1) Update D network
73 |             optimizer_dI.zero_grad()
74 |             # compute fake loss
75 |             _, pred_fake_dI = net_dI(fake_out)
76 |             loss_dI_fake = criterionGAN(pred_fake_dI, False)
77 |             # compute real loss
78 |             _, pred_real_dI = net_dI(source_image_data)
79 |             loss_dI_real = criterionGAN(pred_real_dI, True)
80 |             # Combined DI loss
81 |             loss_dI = (loss_dI_fake + loss_dI_real) * 0.5
82 |             loss_dI.backward(retain_graph=True)
83 |             optimizer_dI.step()
84 |             # (2) Update G network
85 |             _, pred_fake_dI = net_dI(fake_out)
86 |             optimizer_g.zero_grad()
87 |             # compute perception loss
88 |             perception_real = net_vgg(source_image_data)
89 |             perception_fake = net_vgg(fake_out)
90 |             perception_real_half = net_vgg(target_tensor_half)
91 |             perception_fake_half = net_vgg(fake_out_half)
92 |             loss_g_perception = 0
93 |             for i in range(len(perception_real)):
94 |                 loss_g_perception += criterionL1(perception_fake[i], perception_real[i])
95 |                 loss_g_perception += criterionL1(perception_fake_half[i], perception_real_half[i])
96 |             loss_g_perception = (loss_g_perception / (len(perception_real) * 2)) * opt.lamb_perception
97 |             # gan dI loss
98 |             loss_g_dI = criterionGAN(pred_fake_dI, True)
99 |             # combine perception loss and gan loss
100 |             loss_g = loss_g_perception + loss_g_dI
101 |             loss_g.backward()
102 |             optimizer_g.step()
103 |
104 |             print(
105 |                 "===> Epoch[{}]({}/{}): Loss_DI: {:.4f} Loss_GI: {:.4f} Loss_perception: {:.4f} lr_g = {:.7f} ".format(
106 |                     epoch, iteration, len(training_data_loader), float(loss_dI), float(loss_g_dI), float(loss_g_perception), optimizer_g.param_groups[0]['lr']))
107 |
108 |         update_learning_rate(net_g_scheduler, optimizer_g)
109 |         update_learning_rate(net_dI_scheduler, optimizer_dI)
110 |         # checkpoint
111 |         if epoch % opt.checkpoint == 0:
112 |             if not os.path.exists(opt.result_path):
113 |                 os.mkdir(opt.result_path)
114 |             model_out_path = os.path.join(opt.result_path, 'netG_model_epoch_{}.pth'.format(epoch))
115 |             states = {
116 |                 'epoch': epoch + 1,
117 |                 'state_dict': {'net_g': net_g.state_dict(), 'net_dI': net_dI.state_dict()},
118 |                 'optimizer': {'net_g': optimizer_g.state_dict(), 'net_dI': optimizer_dI.state_dict()}
119 |             }
120 |             torch.save(states, model_out_path)
121 |             print("Checkpoint saved to {}".format(model_out_path))
122 |
--------------------------------------------------------------------------------
/utils/data_processing.py:
--------------------------------------------------------------------------------
1 | import csv
2 | import numpy as np
3 | import random
4 |
5 |
6 | def load_landmark_openface(csv_path):
7 |     '''
8 |     load OpenFace landmarks from a .csv file
9 |     '''
10 |     with open(csv_path, 'r') as f:
11 |         reader = csv.reader(f)
12 |         data_all = [row for row in reader]
13 |     x_list = []
14 |     y_list = []
15 |     for row_index, row in enumerate(data_all[1:]):
16 |         frame_num = float(row[0])
17 |         if int(frame_num) != row_index + 1:
18 |             return None
19 |         x_list.append([float(x) for x in row[5:5+68]])
20 |         y_list.append([float(y) for y in row[5+68:5+68+68]])
21 |     x_array = np.array(x_list)
22 |     y_array = np.array(y_list)
23 |     landmark_array = np.stack([x_array, y_array], 2)
24 |     return landmark_array
25 |
26 |
27 | def compute_crop_radius(video_size, landmark_data_clip, random_scale=None):
28 |     '''
29 |     decide whether the face clip can be cropped and compute the crop radius
30 |     '''
31 |     video_w, video_h = video_size[0], video_size[1]
32 |     landmark_max_clip = np.max(landmark_data_clip, axis=1)
33 |     if random_scale is None:
34 |         random_scale = random.random() / 10 + 1.05
35 |     else:
36 |         random_scale = random_scale
37 |     radius_h = (landmark_max_clip[:, 1] - landmark_data_clip[:, 29, 1]) * random_scale
38 |     radius_w = (landmark_data_clip[:, 54, 0] - landmark_data_clip[:, 48, 0]) * random_scale
39 |     radius_clip = np.max(np.stack([radius_h, radius_w], 1), 1) // 2
40 |     radius_max = np.max(radius_clip)
41 |     radius_max = (int(radius_max / 4) + 1) * 4
42 |     radius_max_1_4 = radius_max // 4
43 |     clip_min_h = landmark_data_clip[:, 29,
44 |                                     1] - radius_max
45 |     clip_max_h = landmark_data_clip[:, 29,
46 |                                     1] + radius_max * 2 + radius_max_1_4
47 |     clip_min_w = landmark_data_clip[:, 33,
48 |                                     0] - radius_max - radius_max_1_4
49 |     clip_max_w = landmark_data_clip[:, 33,
50 |                                     0] + radius_max + radius_max_1_4
51 |     if min(clip_min_h.tolist() + clip_min_w.tolist()) < 0:
52 |         return False, None
53 |     elif max(clip_max_h.tolist()) > video_h:
54 |         return False, None
55 |     elif max(clip_max_w.tolist()) > video_w:
56 |         return False, None
57 |     elif max(radius_clip) > min(radius_clip) * 1.5:
58 |         return False, None
59 |     else:
60 |         return True, radius_max
--------------------------------------------------------------------------------
/utils/deep_speech.py:
--------------------------------------------------------------------------------
1 |
2 | import numpy as np
3 | import warnings
4 | import resampy
5 | from scipy.io import wavfile
6 | from python_speech_features import mfcc
7 | import tensorflow as tf
8 |
9 |
10 | class DeepSpeech():
11 |     def __init__(self, model_path):
12 |         self.graph, self.logits_ph, self.input_node_ph, self.input_lengths_ph \
13 |             = self._prepare_deepspeech_net(model_path)
14 |         self.target_sample_rate = 16000
15 |
16 |     def _prepare_deepspeech_net(self, deepspeech_pb_path):
17 |         with tf.io.gfile.GFile(deepspeech_pb_path, "rb") as f:
18 |             graph_def = tf.compat.v1.GraphDef()
19 |             graph_def.ParseFromString(f.read())
20 |         graph = tf.compat.v1.get_default_graph()
21 |         tf.import_graph_def(graph_def, name="deepspeech")
22 |         logits_ph = graph.get_tensor_by_name("deepspeech/logits:0")
23 |         input_node_ph = graph.get_tensor_by_name("deepspeech/input_node:0")
24 |         input_lengths_ph = graph.get_tensor_by_name("deepspeech/input_lengths:0")
25 |
26 |         return graph, logits_ph, input_node_ph, input_lengths_ph
27 |
28 |     def conv_audio_to_deepspeech_input_vector(self, audio,
29 |                                               sample_rate,
30 |                                               num_cepstrum,
31 |                                               num_context):
32 |         # Get mfcc coefficients:
33 |         features = mfcc(
34 |             signal=audio,
35 |             samplerate=sample_rate,
36 |             numcep=num_cepstrum)
37 |
38 |         # We only keep every second feature (BiRNN stride = 2):
39 |         features = features[::2]
40 |
41 |         # One stride per time step in the input:
42 |         num_strides = len(features)
43 |
44 |         # Add empty initial and final contexts:
45 |         empty_context = np.zeros((num_context, num_cepstrum), dtype=features.dtype)
46 |         features = np.concatenate((empty_context, features, empty_context))
47 |
48 |         # Create a view into the array with overlapping strides of size
49 |         # numcontext (past) + 1 (present) + numcontext (future):
50 |         window_size = 2 * num_context + 1
51 |         train_inputs = np.lib.stride_tricks.as_strided(
52 |             features,
53 |             shape=(num_strides, window_size, num_cepstrum),
54 |             strides=(features.strides[0],
55 |                      features.strides[0], features.strides[1]),
56 |             writeable=False)
57 |
58 |         # Flatten the second and third dimensions:
59 |         train_inputs = np.reshape(train_inputs, [num_strides, -1])
60 |
61 |         train_inputs = np.copy(train_inputs)
62 |         train_inputs = (train_inputs - np.mean(train_inputs)) / \
63 |             np.std(train_inputs)
64 |
65 |         return train_inputs
66 |
67 |     def compute_audio_feature(self, audio_path):
68 |         audio_sample_rate, audio = wavfile.read(audio_path)
69 |         if audio.ndim != 1:
70 |             warnings.warn(
71 |                 "Audio has multiple channels, the first channel is used")
72 |             audio = audio[:, 0]
73 |         if audio_sample_rate != self.target_sample_rate:
74 |             resampled_audio = resampy.resample(
75 |                 x=audio.astype(np.float32),
76 |                 sr_orig=audio_sample_rate,
77 |                 sr_new=self.target_sample_rate)
78 |         else:
79 |             resampled_audio = audio.astype(np.float32)
80 |         with tf.compat.v1.Session(graph=self.graph) as sess:
81 |             input_vector = self.conv_audio_to_deepspeech_input_vector(
82 |                 audio=resampled_audio.astype(np.int16),
83 |                 sample_rate=self.target_sample_rate,
84 |                 num_cepstrum=26,
85 |                 num_context=9)
86 |             network_output = sess.run(
87 |                 self.logits_ph,
88 |                 feed_dict={
89 |                     self.input_node_ph: input_vector[np.newaxis, ...],
90 |                     self.input_lengths_ph: [input_vector.shape[0]]})
91 |             ds_features = network_output[::2, 0, :]
92 |         return ds_features
93 |
94 | if __name__ == '__main__':
95 |     audio_path = r'./00168.wav'
96 |     model_path = r'./output_graph.pb'
97 |     DSModel = DeepSpeech(model_path)
98 |     ds_feature = DSModel.compute_audio_feature(audio_path)
99 |     print(ds_feature)
--------------------------------------------------------------------------------
/utils/training_utils.py:
--------------------------------------------------------------------------------
1 | from torch.optim import lr_scheduler
2 | import torch.nn as nn
3 | import torch
4 |
5 | def get_scheduler(optimizer, niter, niter_decay, lr_policy='lambda', lr_decay_iters=50):
6 |     '''
7 |     scheduler in training stage
8 |     '''
9 |     if lr_policy == 'lambda':
10 |         def lambda_rule(epoch):
11 |             lr_l = 1.0 - max(0, epoch - niter) / float(niter_decay + 1)
12 |             return lr_l
13 |         scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
14 |     elif lr_policy == 'step':
15 |         scheduler = lr_scheduler.StepLR(optimizer, step_size=lr_decay_iters, gamma=0.1)
16 |     elif lr_policy == 'plateau':
17 |         scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
18 |     elif lr_policy == 'cosine':
19 |         scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=niter, eta_min=0)
20 |     else:
21 |         raise NotImplementedError('learning rate policy [%s] is not implemented' % lr_policy)
22 |     return scheduler
23 |
24 | def update_learning_rate(scheduler, optimizer):
25 |     scheduler.step()
26 |     lr = optimizer.param_groups[0]['lr']
27 |     print('learning rate = %.7f' % lr)
28 |
29 | class GANLoss(nn.Module):
30 |     '''
31 |     GAN loss
32 |     '''
33 |     def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0):
34 |         super(GANLoss, self).__init__()
35 |         self.register_buffer('real_label', torch.tensor(target_real_label))
36 |         self.register_buffer('fake_label', torch.tensor(target_fake_label))
37 |         if use_lsgan:
38 |             self.loss = nn.MSELoss()
39 |         else:
40 |             self.loss = nn.BCELoss()
41 |
42 |     def get_target_tensor(self, input, target_is_real):
43 |         if target_is_real:
44 |             target_tensor = self.real_label
45 |         else:
46 |             target_tensor = self.fake_label
47 |         return target_tensor.expand_as(input)
48 |
49 |     def forward(self, input, target_is_real):
50 |         target_tensor = self.get_target_tensor(input, target_is_real)
51 |         return self.loss(input, target_tensor)
--------------------------------------------------------------------------------
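
The clip training loop in train_DINet_clip.py relies on two `torch.split`/`torch.cat` idioms to move between per-clip and per-frame tensors. The sketch below only illustrates that shape bookkeeping with made-up sizes (2 clips per batch, 5 frames per clip, 3x64x64 frames); the real shapes come from config/config.py and the clip dataset.
```python
import torch

# Illustrative sizes only: 2 clips per batch, 5 frames per clip, 3x64x64 frames.
B, T, C, H, W = 2, 5, 3, 64, 64
clip = torch.randn(B, T, C, H, W)

# Fold the frame axis into the batch axis, as done before calling net_g,
# so the generator processes every frame as an independent sample.
frames = torch.cat(torch.split(clip, 1, dim=1), 0).squeeze(1)
assert frames.shape == (T * B, C, H, W)

# Regroup the generated frames clip-by-clip for net_dV / the SyncNet:
# the frames of one clip are stacked along the channel axis.
regrouped = torch.cat(torch.split(frames, B, dim=0), 1)
assert regrouped.shape == (B, T * C, H, W)
```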
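
utils/data_processing.py provides the two helpers used during data preparation: `load_landmark_openface` reads an OpenFace csv into a `(frames, 68, 2)` array (returning `None` if the frame indices are not consecutive), and `compute_crop_radius` decides whether a short landmark clip can be cropped and, if so, returns a crop radius that is a multiple of 4. A minimal usage sketch, assuming a hypothetical csv path, a 1920x1080 video, and an example window of 5 consecutive frames:
```python
from utils.data_processing import load_landmark_openface, compute_crop_radius

# Hypothetical csv exported by OpenFace for one training video.
landmark_array = load_landmark_openface('./asserts/split_video_25fps_landmark_openface/test1.csv')
if landmark_array is None:
    raise ValueError('frame indices in the csv are not consecutive')

video_size = (1920, 1080)     # (width, height) of the source video
clip = landmark_array[0:5]    # landmarks of 5 consecutive frames, shape (5, 68, 2)
flag, crop_radius = compute_crop_radius(video_size, clip)
if flag:
    print('crop radius (multiple of 4):', crop_radius)
else:
    print('clip rejected: crop box leaves the frame or the face scale varies too much')
```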
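
utils/training_utils.py is shared by both training scripts: `get_scheduler` with the default 'lambda' policy keeps the learning rate constant for `niter` epochs and then decays it roughly linearly to zero over the next `niter_decay` epochs, and `GANLoss` defaults to the least-squares (MSE) formulation. A small standalone sketch; the linear layer, optimizer, and epoch counts are placeholders rather than values from the DINet configs:
```python
import torch
import torch.nn as nn
import torch.optim as optim
from utils.training_utils import get_scheduler, update_learning_rate, GANLoss

model = nn.Linear(8, 8)                                 # placeholder network
optimizer = optim.Adam(model.parameters(), lr=1e-4)
scheduler = get_scheduler(optimizer, niter=20, niter_decay=20)

criterion_gan = GANLoss()                               # LSGAN (MSE) by default
logits = torch.randn(4, 1)                              # stand-in discriminator output
loss_real = criterion_gan(logits, True)                 # compared against a tensor of ones
loss_fake = criterion_gan(logits, False)                # compared against a tensor of zeros

for epoch in range(3):
    # ... one epoch of training would go here ...
    update_learning_rate(scheduler, optimizer)          # steps the scheduler and prints the new lr
```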