├── README.md
├── config
│   ├── parameters.yaml
│   └── vox-256.yaml
├── demo
│   ├── audio
│   │   └── intro.wav
│   └── img
│       ├── baiden.jpg
│       ├── masike.jpg
│       ├── obama.jpg
│       ├── paint.jpg
│       ├── paint1.jpg
│       ├── paint2.jpg
│       ├── statue.jpg
│       ├── trump.jpg
│       └── trump2.jpg
├── inference.py
├── modules
│   ├── audio2kp.py
│   ├── audio2pose.py
│   ├── dense_motion.py
│   ├── generator.py
│   ├── keypoint_detector.py
│   ├── resnet.py
│   └── util.py
├── requirements.txt
└── sync_batchnorm
    ├── __init__.py
    ├── batchnorm.py
    ├── comm.py
    ├── replicate.py
    └── unittest.py

/README.md:
--------------------------------------------------------------------------------
1 | # Audio2Head: Audio-driven One-shot Talking-head Generation with Natural Head Motion (IJCAI 2021)
2 | 
3 | #### [Paper](https://www.ijcai.org/proceedings/2021/0152.pdf) | [Demo](https://www.youtube.com/watch?v=xvcBJ29l8rA)
4 | 
5 | #### Requirements
6 | 
7 | - Python 3.6, PyTorch >= 1.6, and ffmpeg
8 | 
9 | - Other requirements are listed in `requirements.txt`
10 | 
11 | 
12 | 
13 | #### Pretrained Checkpoint
14 | 
15 | Please download the pretrained checkpoint from [google-drive](https://drive.google.com/file/d/1tvI43ZIrnx9Ti2TpFiEO4dK5DOwcECD7/view?usp=sharing) and put it in the `/checkpoints` folder.
16 | 
17 | 
18 | 
19 | #### Generate Demo Results
20 | 
21 | ```
22 | python inference.py --audio_path xxx.wav --img_path xxx.jpg
23 | ```
24 | 
25 | Note that the input image must be square (equal height and width), with the face appropriately cropped as in the examples under `/demo/img`.
26 | 
27 | 
28 | 
29 | #### License and Citation
30 | 
31 | ```
32 | @InProceedings{wang2021audio2head,
33 |   author    = {Suzhen Wang and Lincheng Li and Yu Ding and Changjie Fan and Xin Yu},
34 |   title     = {Audio2Head: Audio-driven One-shot Talking-head Generation with Natural Head Motion},
35 |   booktitle = {Proceedings of the 30th International Joint Conference on Artificial Intelligence (IJCAI-21)},
36 |   year      = {2021},
37 | }
38 | ```
39 | 
40 | 
41 | 
42 | #### Acknowledgement
43 | 
44 | This codebase is based on [First Order Motion Model](https://github.com/AliaksandrSiarohin/first-order-model); thanks to its authors for their contribution.
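The note under "Generate Demo Results" above asks for a square, face-centered input image. As a rough illustration only (this script is not part of the repository, and the file names are placeholders), the sketch below center-crops an arbitrary photo to a square and resizes it to 256×256 before it is passed to `inference.py`; for good results the face should be framed roughly as in the samples under `/demo/img`.

```python
# prepare_image.py -- illustrative pre-processing sketch (not part of this repo).
# Assumption: the photo is already roughly face-centered; we center-crop to a
# square and resize to 256x256, matching the framing of the ./demo/img samples.
import cv2

def make_square_256(src_path, dst_path):
    img = cv2.imread(src_path)                # BGR image, shape (H, W, 3)
    h, w = img.shape[:2]
    side = min(h, w)
    top = (h - side) // 2
    left = (w - side) // 2
    crop = img[top:top + side, left:left + side]
    crop = cv2.resize(crop, (256, 256), interpolation=cv2.INTER_AREA)
    cv2.imwrite(dst_path, crop)

if __name__ == "__main__":
    make_square_256("my_photo.jpg", "my_photo_256.jpg")
    # then: python inference.py --audio_path ./demo/audio/intro.wav --img_path my_photo_256.jpg
```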
45 | 46 | 47 | 48 | 49 | 50 | -------------------------------------------------------------------------------- /config/parameters.yaml: -------------------------------------------------------------------------------- 1 | block_expansion: 32 2 | estimate_jacobian: true 3 | max_features: 512 4 | num_blocks: 5 5 | num_kp: 10 6 | num_w: 2 7 | seq: true 8 | seq_len: 64 -------------------------------------------------------------------------------- /config/vox-256.yaml: -------------------------------------------------------------------------------- 1 | dataset_params: 2 | root_dir: /root/ 3 | frame_shape: [256, 256, 3] 4 | id_sampling: True 5 | pairs_list: data/vox256.csv 6 | augmentation_params: 7 | flip_param: 8 | horizontal_flip: True 9 | time_flip: True 10 | jitter_param: 11 | brightness: 0.1 12 | contrast: 0.1 13 | saturation: 0.1 14 | hue: 0.1 15 | 16 | 17 | model_params: 18 | common_params: 19 | num_kp: 10 20 | num_channels: 3 21 | estimate_jacobian: True 22 | kp_detector_params: 23 | temperature: 0.1 24 | block_expansion: 32 25 | max_features: 1024 26 | scale_factor: 0.25 27 | num_blocks: 5 28 | generator_params: 29 | block_expansion: 64 30 | max_features: 512 31 | num_down_blocks: 2 32 | num_bottleneck_blocks: 6 33 | estimate_occlusion_map: True 34 | dense_motion_params: 35 | block_expansion: 64 36 | max_features: 1024 37 | num_blocks: 5 38 | scale_factor: 0.25 39 | discriminator_params: 40 | scales: [1] 41 | block_expansion: 32 42 | max_features: 512 43 | num_blocks: 4 44 | sn: True 45 | 46 | train_params: 47 | num_epochs: 100 48 | num_repeats: 50 49 | epoch_milestones: [5, 20, 30] 50 | lr_generator: 2.0e-4 51 | lr_discriminator: 2.0e-4 52 | lr_kp_detector: 2.0e-4 53 | batch_size: 36 54 | scales: [1, 0.5, 0.25, 0.125] 55 | checkpoint_freq: 10 56 | transform_params: 57 | sigma_affine: 0.05 58 | sigma_tps: 0.005 59 | points_tps: 5 60 | loss_weights: 61 | generator_gan: 0 62 | discriminator_gan: 1 63 | feature_matching: [10, 10, 10, 10] 64 | perceptual: [10, 10, 10, 10, 10] 65 | equivariance_value: 10 66 | equivariance_jacobian: 10 67 | 68 | reconstruction_params: 69 | num_videos: 1000 70 | format: '.mp4' 71 | 72 | animate_params: 73 | num_pairs: 50 74 | format: '.mp4' 75 | normalization_params: 76 | adapt_movement_scale: False 77 | use_relative_movement: True 78 | use_relative_jacobian: True 79 | 80 | visualizer_params: 81 | kp_size: 5 82 | draw_border: True 83 | colormap: 'gist_rainbow' 84 | -------------------------------------------------------------------------------- /demo/audio/intro.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FuxiVirtualHuman/Audio2Head/09e9b431e48a6358c2877a12cd45457ff0379455/demo/audio/intro.wav -------------------------------------------------------------------------------- /demo/img/baiden.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FuxiVirtualHuman/Audio2Head/09e9b431e48a6358c2877a12cd45457ff0379455/demo/img/baiden.jpg -------------------------------------------------------------------------------- /demo/img/masike.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FuxiVirtualHuman/Audio2Head/09e9b431e48a6358c2877a12cd45457ff0379455/demo/img/masike.jpg -------------------------------------------------------------------------------- /demo/img/obama.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/FuxiVirtualHuman/Audio2Head/09e9b431e48a6358c2877a12cd45457ff0379455/demo/img/obama.jpg -------------------------------------------------------------------------------- /demo/img/paint.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FuxiVirtualHuman/Audio2Head/09e9b431e48a6358c2877a12cd45457ff0379455/demo/img/paint.jpg -------------------------------------------------------------------------------- /demo/img/paint1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FuxiVirtualHuman/Audio2Head/09e9b431e48a6358c2877a12cd45457ff0379455/demo/img/paint1.jpg -------------------------------------------------------------------------------- /demo/img/paint2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FuxiVirtualHuman/Audio2Head/09e9b431e48a6358c2877a12cd45457ff0379455/demo/img/paint2.jpg -------------------------------------------------------------------------------- /demo/img/statue.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FuxiVirtualHuman/Audio2Head/09e9b431e48a6358c2877a12cd45457ff0379455/demo/img/statue.jpg -------------------------------------------------------------------------------- /demo/img/trump.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FuxiVirtualHuman/Audio2Head/09e9b431e48a6358c2877a12cd45457ff0379455/demo/img/trump.jpg -------------------------------------------------------------------------------- /demo/img/trump2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FuxiVirtualHuman/Audio2Head/09e9b431e48a6358c2877a12cd45457ff0379455/demo/img/trump2.jpg -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import subprocess 3 | import python_speech_features 4 | from scipy.io import wavfile 5 | from scipy.interpolate import interp1d 6 | import numpy as np 7 | import pyworld 8 | import torch 9 | from modules.audio2pose import get_pose_from_audio 10 | from skimage import io, img_as_float32 11 | import cv2 12 | from modules.generator import OcclusionAwareGenerator 13 | from modules.keypoint_detector import KPDetector 14 | from modules.audio2kp import AudioModel3D 15 | import yaml,os,imageio 16 | 17 | def draw_annotation_box( image, rotation_vector, translation_vector, color=(255, 255, 255), line_width=2): 18 | """Draw a 3D box as annotation of pose""" 19 | 20 | camera_matrix = np.array( 21 | [[233.333, 0, 128], 22 | [0, 233.333, 128], 23 | [0, 0, 1]], dtype="double") 24 | 25 | dist_coeefs = np.zeros((4, 1)) 26 | 27 | point_3d = [] 28 | rear_size = 75 29 | rear_depth = 0 30 | point_3d.append((-rear_size, -rear_size, rear_depth)) 31 | point_3d.append((-rear_size, rear_size, rear_depth)) 32 | point_3d.append((rear_size, rear_size, rear_depth)) 33 | point_3d.append((rear_size, -rear_size, rear_depth)) 34 | point_3d.append((-rear_size, -rear_size, rear_depth)) 35 | 36 | front_size = 100 37 | front_depth = 100 38 | point_3d.append((-front_size, -front_size, front_depth)) 39 | point_3d.append((-front_size, front_size, front_depth)) 40 | point_3d.append((front_size, front_size, 
front_depth)) 41 | point_3d.append((front_size, -front_size, front_depth)) 42 | point_3d.append((-front_size, -front_size, front_depth)) 43 | point_3d = np.array(point_3d, dtype=np.float).reshape(-1, 3) 44 | 45 | # Map to 2d image points 46 | (point_2d, _) = cv2.projectPoints(point_3d, 47 | rotation_vector, 48 | translation_vector, 49 | camera_matrix, 50 | dist_coeefs) 51 | point_2d = np.int32(point_2d.reshape(-1, 2)) 52 | 53 | # Draw all the lines 54 | cv2.polylines(image, [point_2d], True, color, line_width, cv2.LINE_AA) 55 | cv2.line(image, tuple(point_2d[1]), tuple( 56 | point_2d[6]), color, line_width, cv2.LINE_AA) 57 | cv2.line(image, tuple(point_2d[2]), tuple( 58 | point_2d[7]), color, line_width, cv2.LINE_AA) 59 | cv2.line(image, tuple(point_2d[3]), tuple( 60 | point_2d[8]), color, line_width, cv2.LINE_AA) 61 | 62 | def inter_pitch(y,y_flag): 63 | frame_num = y.shape[0] 64 | i = 0 65 | last = -1 66 | while(i= frame_num: 79 | break 80 | elif last == -1: 81 | y[:i] = y[i] 82 | else: 83 | inter_num = i-last+1 84 | fy = np.array([y[last],y[i]]) 85 | fx = np.linspace(0, 1, num=2) 86 | f = interp1d(fx,fy) 87 | fx_new = np.linspace(0,1,inter_num) 88 | fy_new = f(fx_new) 89 | y[last+1:i] = fy_new[1:-1] 90 | last = i 91 | i+=1 92 | 93 | else: 94 | last = i 95 | i+=1 96 | return y 97 | 98 | def get_audio_feature_from_audio(audio_path,norm = True): 99 | sample_rate, audio = wavfile.read(audio_path) 100 | if len(audio.shape) == 2: 101 | if np.min(audio[:, 0]) <= 0: 102 | audio = audio[:, 1] 103 | else: 104 | audio = audio[:, 0] 105 | if norm: 106 | audio = audio - np.mean(audio) 107 | audio = audio / np.max(np.abs(audio)) 108 | a = python_speech_features.mfcc(audio, sample_rate) 109 | b = python_speech_features.logfbank(audio, sample_rate) 110 | c, _ = pyworld.harvest(audio, sample_rate, frame_period=10) 111 | c_flag = (c == 0.0) ^ 1 112 | c = inter_pitch(c, c_flag) 113 | c = np.expand_dims(c, axis=1) 114 | c_flag = np.expand_dims(c_flag, axis=1) 115 | frame_num = np.min([a.shape[0], b.shape[0], c.shape[0]]) 116 | 117 | cat = np.concatenate([a[:frame_num], b[:frame_num], c[:frame_num], c_flag[:frame_num]], axis=1) 118 | return cat 119 | 120 | def audio2head(audio_path, img_path, model_path, save_path): 121 | temp_audio="./results/temp.wav" 122 | command = ("ffmpeg -y -i %s -async 1 -ac 1 -vn -acodec pcm_s16le -ar 16000 %s" % (audio_path, temp_audio)) 123 | output = subprocess.call(command, shell=True, stdout=None) 124 | 125 | audio_feature = get_audio_feature_from_audio(temp_audio) 126 | frames = len(audio_feature) // 4 127 | 128 | img = io.imread(img_path)[:, :, :3] 129 | img = cv2.resize(img, (256, 256)) 130 | 131 | img = np.array(img_as_float32(img)) 132 | img = img.transpose((2, 0, 1)) 133 | img = torch.from_numpy(img).unsqueeze(0).cuda() 134 | 135 | 136 | ref_pose_rot, ref_pose_trans = get_pose_from_audio(img, audio_feature, model_path) 137 | torch.cuda.empty_cache() 138 | 139 | config_file = r"./config/vox-256.yaml" 140 | with open(config_file) as f: 141 | config = yaml.load(f) 142 | kp_detector = KPDetector(**config['model_params']['kp_detector_params'], 143 | **config['model_params']['common_params']) 144 | generator = OcclusionAwareGenerator(**config['model_params']['generator_params'], 145 | **config['model_params']['common_params']) 146 | kp_detector = kp_detector.cuda() 147 | generator = generator.cuda() 148 | 149 | opt = argparse.Namespace(**yaml.load(open("./config/parameters.yaml"))) 150 | audio2kp = AudioModel3D(opt).cuda() 151 | 152 | checkpoint = torch.load(model_path) 153 | 
kp_detector.load_state_dict(checkpoint["kp_detector"]) 154 | generator.load_state_dict(checkpoint["generator"]) 155 | audio2kp.load_state_dict(checkpoint["audio2kp"]) 156 | 157 | generator.eval() 158 | kp_detector.eval() 159 | audio2kp.eval() 160 | 161 | audio_f = [] 162 | poses = [] 163 | pad = np.zeros((4,41),dtype=np.float32) 164 | for i in range(0, frames, opt.seq_len // 2): 165 | temp_audio = [] 166 | temp_pos = [] 167 | for j in range(opt.seq_len): 168 | if i + j < frames: 169 | temp_audio.append(audio_feature[(i+j)*4:(i+j)*4+4]) 170 | trans = ref_pose_trans[i + j] 171 | rot = ref_pose_rot[i + j] 172 | else: 173 | temp_audio.append(pad) 174 | trans = ref_pose_trans[-1] 175 | rot = ref_pose_rot[-1] 176 | 177 | pose = np.zeros([256, 256]) 178 | draw_annotation_box(pose, np.array(rot), np.array(trans)) 179 | temp_pos.append(pose) 180 | audio_f.append(temp_audio) 181 | poses.append(temp_pos) 182 | 183 | audio_f = torch.from_numpy(np.array(audio_f,dtype=np.float32)).unsqueeze(0) 184 | poses = torch.from_numpy(np.array(poses, dtype=np.float32)).unsqueeze(0) 185 | 186 | bs = audio_f.shape[1] 187 | predictions_gen = [] 188 | total_frames = 0 189 | 190 | for bs_idx in range(bs): 191 | t = {} 192 | 193 | t["audio"] = audio_f[:, bs_idx].cuda() 194 | t["pose"] = poses[:, bs_idx].cuda() 195 | t["id_img"] = img 196 | kp_gen_source = kp_detector(img) 197 | 198 | gen_kp = audio2kp(t) 199 | if bs_idx == 0: 200 | startid = 0 201 | end_id = opt.seq_len // 4 * 3 202 | else: 203 | startid = opt.seq_len // 4 204 | end_id = opt.seq_len // 4 * 3 205 | 206 | for frame_bs_idx in range(startid, end_id): 207 | tt = {} 208 | tt["value"] = gen_kp["value"][:, frame_bs_idx] 209 | if opt.estimate_jacobian: 210 | tt["jacobian"] = gen_kp["jacobian"][:, frame_bs_idx] 211 | out_gen = generator(img, kp_source=kp_gen_source, kp_driving=tt) 212 | out_gen["kp_source"] = kp_gen_source 213 | out_gen["kp_driving"] = tt 214 | del out_gen['sparse_deformed'] 215 | del out_gen['occlusion_map'] 216 | del out_gen['deformed'] 217 | predictions_gen.append( 218 | (np.transpose(out_gen['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0] * 255).astype(np.uint8)) 219 | 220 | total_frames += 1 221 | if total_frames >= frames: 222 | break 223 | if total_frames >= frames: 224 | break 225 | 226 | log_dir = save_path 227 | if not os.path.exists(os.path.join(log_dir, "temp")): 228 | os.makedirs(os.path.join(log_dir, "temp")) 229 | image_name = os.path.basename(img_path)[:-4]+ "_" + os.path.basename(audio_path)[:-4] + ".mp4" 230 | 231 | video_path = os.path.join(log_dir, "temp", image_name) 232 | 233 | imageio.mimsave(video_path, predictions_gen, fps=25.0) 234 | 235 | save_video = os.path.join(log_dir, image_name) 236 | cmd = r'ffmpeg -y -i "%s" -i "%s" -vcodec copy "%s"' % (video_path, audio_path, save_video) 237 | os.system(cmd) 238 | os.remove(video_path) 239 | 240 | 241 | if __name__ == '__main__': 242 | parser = argparse.ArgumentParser() 243 | parser.add_argument("--audio_path",default=r"./demo/audio/intro.wav",help="audio file sampled as 16k hz") 244 | parser.add_argument("--img_path",default=r"./demo/img/paint.jpg", help="reference image") 245 | parser.add_argument("--save_path",default=r"./results", help="save path") 246 | parser.add_argument("--model_path",default=r"./checkpoints/audio2head.pth.tar", help="pretrained model path") 247 | 248 | parse = parser.parse_args() 249 | 250 | os.makedirs(parse.save_path,exist_ok=True) 251 | audio2head(parse.audio_path,parse.img_path,parse.model_path,parse.save_path) 
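Since `inference.py` above exposes `audio2head(audio_path, img_path, model_path, save_path)` as a plain function, it can also be driven programmatically rather than through the CLI. The following minimal sketch mirrors the defaults of the `__main__` block; the `run_demo.py` name is hypothetical, the checkpoint is assumed to have been downloaded to `./checkpoints/audio2head.pth.tar` as described in the README, and a CUDA-capable GPU is required because the models are moved to `.cuda()`.

```python
# run_demo.py -- minimal programmatic use of inference.audio2head (sketch).
import os
from inference import audio2head

save_path = "./results"
os.makedirs(save_path, exist_ok=True)  # audio2head also writes ./results/temp.wav

audio2head(audio_path="./demo/audio/intro.wav",
           img_path="./demo/img/paint.jpg",
           model_path="./checkpoints/audio2head.pth.tar",
           save_path=save_path)
# The muxed result is saved as ./results/paint_intro.mp4
# (image basename + "_" + audio basename, rendered at 25 fps).
```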
-------------------------------------------------------------------------------- /modules/audio2kp.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch 3 | import torch.nn.functional as F 4 | from modules.util import AntiAliasInterpolation2d 5 | from modules.util import Hourglass3D 6 | 7 | from modules.util import gaussian2kp 8 | from sync_batchnorm import SynchronizedBatchNorm2d as BatchNorm2d 9 | 10 | 11 | class AudioModel3D(nn.Module): 12 | def __init__(self,opt): 13 | super(AudioModel3D,self).__init__() 14 | self.opt = opt 15 | self.seq_len = opt.seq_len 16 | self.pad = 0 17 | 18 | self.down_id = AntiAliasInterpolation2d(3,0.25) 19 | self.down_pose = AntiAliasInterpolation2d(opt.seq_len,0.25) 20 | 21 | self.embedding = nn.Sequential(nn.ConvTranspose2d(1, 8, (29, 14), stride=(1, 1), padding=(0, 11)), 22 | BatchNorm2d(8), 23 | nn.ReLU(inplace=True), 24 | nn.Conv2d(8, 2, (13, 13), stride=(1, 1), padding=(6, 6))) 25 | 26 | num_channels = 6 27 | self.predictor = Hourglass3D(opt.block_expansion, in_features=num_channels, 28 | max_features=opt.max_features, num_blocks=opt.num_blocks) 29 | 30 | self.kp = nn.Conv3d(in_channels=self.predictor.out_filters, out_channels=opt.num_kp, kernel_size=(7, 7, 7), 31 | padding=(3,0,0)) 32 | if opt.estimate_jacobian: 33 | self.num_jacobian_maps = opt.num_kp 34 | self.jacobian = nn.Conv2d(in_channels=self.predictor.out_filters, 35 | out_channels=4 * self.num_jacobian_maps, kernel_size=(7, 7), padding=(0,0)) 36 | self.jacobian.weight.data.zero_() 37 | self.jacobian.bias.data.copy_(torch.tensor([1, 0, 0, 1] * self.num_jacobian_maps, dtype=torch.float)) 38 | else: 39 | self.jacobian = None 40 | 41 | self.temperature = 0.1 42 | 43 | 44 | def forward(self, x): 45 | bs,_,_,c_dim = x["audio"].shape 46 | 47 | audio_embedding = self.embedding(x["audio"].reshape(-1,1,4,c_dim)) 48 | audio_embedding = F.interpolate(audio_embedding,scale_factor=2).reshape(bs,self.opt.seq_len,2,64,64).permute(0,2,1,3,4) 49 | 50 | id_feature = self.down_id(x["id_img"]) 51 | pose_feature = self.down_pose(x["pose"]) 52 | 53 | embeddings = torch.cat([audio_embedding,id_feature.unsqueeze(2).repeat(1,1,self.opt.seq_len,1,1),pose_feature.unsqueeze(1)],dim=1) 54 | 55 | feature_map = self.predictor(embeddings) 56 | feature_shape = feature_map.shape 57 | prediction = self.kp(feature_map).permute(0,2,1,3,4) 58 | prediction = prediction.reshape(-1,prediction.shape[2],prediction.shape[3],prediction.shape[4]) 59 | final_shape = prediction.shape 60 | heatmap = prediction.view(final_shape[0], final_shape[1], -1) 61 | heatmap = F.softmax(heatmap / self.temperature, dim=2) 62 | heatmap = heatmap.view(*final_shape) 63 | 64 | out = gaussian2kp(heatmap) 65 | out["value"] = out["value"].reshape(-1,self.opt.seq_len,self.opt.num_kp,2) 66 | if self.jacobian is not None: 67 | jacobian_map = self.jacobian(feature_map.permute(0,2,1,3,4).reshape(-1, feature_shape[1],feature_shape[3],feature_shape[4])) 68 | 69 | jacobian_map = jacobian_map.reshape(final_shape[0], self.num_jacobian_maps, 4, final_shape[2], 70 | final_shape[3]) 71 | out["jacobian_map"] = jacobian_map 72 | heatmap = heatmap.unsqueeze(2) 73 | 74 | jacobian = heatmap * jacobian_map 75 | jacobian = jacobian.view(final_shape[0], final_shape[1], 4, -1) 76 | jacobian = jacobian.sum(dim=-1) 77 | jacobian = jacobian.view(jacobian.shape[0], jacobian.shape[1], 2, 2) 78 | out['jacobian'] = jacobian.reshape(-1,self.seq_len,self.opt.num_kp,2,2) 79 | 80 | out["pred_fature"] = prediction 81 | 
return out 82 | 83 | -------------------------------------------------------------------------------- /modules/audio2pose.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from modules.util import MyResNet34 4 | import numpy as np 5 | 6 | class audio2poseLSTM(nn.Module): 7 | def __init__(self): 8 | super(audio2poseLSTM,self).__init__() 9 | 10 | self.em_audio = MyResNet34(256, 1) 11 | self.em_img = MyResNet34(256, 3) 12 | 13 | self.lstm = nn.LSTM(512,256,num_layers=2,bias=True,batch_first=True) 14 | self.output = nn.Linear(256,6) 15 | 16 | def forward(self,x): 17 | img_em = self.em_img(x['img']) 18 | result = [self.output(img_em).unsqueeze(1)] 19 | bs,seqlen,_,_ = x["audio"].shape 20 | zero_state = torch.zeros((2,bs,256),requires_grad=True).to(img_em.device) 21 | cur_state = (zero_state,zero_state) 22 | audio = x["audio"].reshape(-1, 1, 4, 41) 23 | audio_em = self.em_audio(audio).reshape(bs, seqlen, 256) 24 | for i in range(seqlen): 25 | 26 | img_em,cur_state = self.lstm(torch.cat((audio_em[:,i:i+1],img_em.unsqueeze(1)),dim=2),cur_state) 27 | img_em = img_em.reshape(-1, 256) 28 | result.append(self.output(img_em).unsqueeze(1)) 29 | res = torch.cat(result,dim=1) 30 | return res 31 | 32 | def get_pose_from_audio(img,audio,model_path): 33 | num_frame = len(audio) // 4 34 | minv = np.array([-0.639, -0.501, -0.47, -102.6, -32.5, 184.6], dtype=np.float32) 35 | maxv = np.array([0.411, 0.547, 0.433, 159.1, 116.5, 376.5], dtype=np.float32) 36 | 37 | 38 | generator = audio2poseLSTM().cuda() 39 | 40 | ckpt_para = torch.load(model_path) 41 | 42 | generator.load_state_dict(ckpt_para["audio2pose"]) 43 | generator.eval() 44 | 45 | audio_seq = [] 46 | for i in range(num_frame): 47 | audio_seq.append(audio[i*4:i*4+4]) 48 | 49 | audio = torch.from_numpy(np.array(audio_seq,dtype=np.float32)).unsqueeze(0).cuda() 50 | 51 | x = {} 52 | x["img"] = img 53 | x["audio"] = audio 54 | poses = generator(x) 55 | 56 | print(poses.shape) 57 | poses = poses.cpu().data.numpy()[0] 58 | 59 | poses = (poses+1)/2*(maxv-minv)+minv 60 | rot,trans = poses[:,:3].copy(),poses[:,3:].copy() 61 | return rot,trans -------------------------------------------------------------------------------- /modules/dense_motion.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch.nn.functional as F 3 | import torch 4 | from modules.util import Hourglass, AntiAliasInterpolation2d, make_coordinate_grid, kp2gaussian 5 | 6 | class DenseMotionNetwork(nn.Module): 7 | """ 8 | Module that predicting a dense motion from sparse motion representation given by kp_source and kp_driving 9 | """ 10 | 11 | def __init__(self, block_expansion, num_blocks, max_features, num_kp, num_channels, estimate_occlusion_map=False, 12 | scale_factor=1, kp_variance=0.01): 13 | super(DenseMotionNetwork, self).__init__() 14 | self.hourglass = Hourglass(block_expansion=block_expansion, in_features=(num_kp + 1) * (num_channels + 1), 15 | max_features=max_features, num_blocks=num_blocks) 16 | 17 | self.mask = nn.Conv2d(self.hourglass.out_filters, num_kp + 1, kernel_size=(7, 7), padding=(3, 3)) 18 | 19 | if estimate_occlusion_map: 20 | self.occlusion = nn.Conv2d(self.hourglass.out_filters, 1, kernel_size=(7, 7), padding=(3, 3)) 21 | else: 22 | self.occlusion = None 23 | 24 | self.num_kp = num_kp 25 | self.scale_factor = scale_factor 26 | self.kp_variance = kp_variance 27 | 28 | if self.scale_factor != 1: 29 | self.down = 
AntiAliasInterpolation2d(num_channels, self.scale_factor) 30 | 31 | def create_heatmap_representations(self, source_image, kp_driving, kp_source): 32 | """ 33 | Eq 6. in the paper H_k(z) 34 | """ 35 | spatial_size = source_image.shape[2:] 36 | gaussian_driving = kp2gaussian(kp_driving, spatial_size=spatial_size, kp_variance=self.kp_variance) 37 | gaussian_source = kp2gaussian(kp_source, spatial_size=spatial_size, kp_variance=self.kp_variance) 38 | heatmap = gaussian_driving - gaussian_source 39 | 40 | #adding background feature 41 | zeros = torch.zeros(heatmap.shape[0], 1, spatial_size[0], spatial_size[1]).type(heatmap.type()) 42 | heatmap = torch.cat([zeros, heatmap], dim=1) 43 | heatmap = heatmap.unsqueeze(2) 44 | return heatmap 45 | 46 | def create_sparse_motions(self, source_image, kp_driving, kp_source): 47 | """ 48 | Eq 4. in the paper T_{s<-d}(z) 49 | """ 50 | bs, _, h, w = source_image.shape 51 | identity_grid = make_coordinate_grid((h, w), type=kp_source['value'].type()) 52 | identity_grid = identity_grid.view(1, 1, h, w, 2) 53 | coordinate_grid = identity_grid - kp_driving['value'].view(bs, self.num_kp, 1, 1, 2) 54 | if 'jacobian' in kp_driving: 55 | jacobian = torch.matmul(kp_source['jacobian'], torch.inverse(kp_driving['jacobian'])) 56 | jacobian = jacobian.unsqueeze(-3).unsqueeze(-3) 57 | jacobian = jacobian.repeat(1, 1, h, w, 1, 1) 58 | coordinate_grid = torch.matmul(jacobian, coordinate_grid.unsqueeze(-1)) 59 | coordinate_grid = coordinate_grid.squeeze(-1) 60 | 61 | driving_to_source = coordinate_grid + kp_source['value'].view(bs, self.num_kp, 1, 1, 2) 62 | 63 | #adding background feature 64 | identity_grid = identity_grid.repeat(bs, 1, 1, 1, 1) 65 | sparse_motions = torch.cat([identity_grid, driving_to_source], dim=1) 66 | return sparse_motions 67 | 68 | def create_deformed_source_image(self, source_image, sparse_motions): 69 | """ 70 | Eq 7. 
in the paper \hat{T}_{s<-d}(z) 71 | """ 72 | bs, _, h, w = source_image.shape 73 | source_repeat = source_image.unsqueeze(1).unsqueeze(1).repeat(1, self.num_kp + 1, 1, 1, 1, 1) 74 | source_repeat = source_repeat.view(bs * (self.num_kp + 1), -1, h, w) 75 | sparse_motions = sparse_motions.view((bs * (self.num_kp + 1), h, w, -1)) 76 | sparse_deformed = F.grid_sample(source_repeat, sparse_motions) 77 | # sparse_deformed = F.grid_sample(source_repeat, sparse_motions,align_corners = False) 78 | sparse_deformed = sparse_deformed.view((bs, self.num_kp + 1, -1, h, w)) 79 | return sparse_deformed 80 | 81 | def forward(self, source_image, kp_driving, kp_source): 82 | if self.scale_factor != 1: 83 | source_image = self.down(source_image) 84 | 85 | bs, _, h, w = source_image.shape 86 | 87 | out_dict = dict() 88 | heatmap_representation = self.create_heatmap_representations(source_image, kp_driving, kp_source)#bs*(numkp+1)*1*h*w 89 | sparse_motion = self.create_sparse_motions(source_image, kp_driving, kp_source)#bs*(numkp+1)*h*w*2 90 | deformed_source = self.create_deformed_source_image(source_image, sparse_motion) 91 | out_dict['sparse_deformed'] = deformed_source 92 | 93 | input = torch.cat([heatmap_representation, deformed_source], dim=2)#bs*num+1*4*w*h 94 | input = input.view(bs, -1, h, w) 95 | 96 | prediction = self.hourglass(input) 97 | 98 | mask = self.mask(prediction) 99 | mask = F.softmax(mask, dim=1) 100 | out_dict['mask'] = mask 101 | mask = mask.unsqueeze(2)#bs*numkp+1*1*h*w 102 | sparse_motion = sparse_motion.permute(0, 1, 4, 2, 3) 103 | deformation = (sparse_motion * mask).sum(dim=1)# bs,2,64,64 104 | deformation = deformation.permute(0, 2, 3, 1)#bs*h*w*2 105 | 106 | out_dict['deformation'] = deformation 107 | 108 | # Sec. 3.2 in the paper 109 | if self.occlusion: 110 | occlusion_map = torch.sigmoid(self.occlusion(prediction)) 111 | out_dict['occlusion_map'] = occlusion_map 112 | 113 | return out_dict 114 | -------------------------------------------------------------------------------- /modules/generator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from modules.util import ResBlock2d, SameBlock2d, UpBlock2d, DownBlock2d 5 | from modules.dense_motion import DenseMotionNetwork 6 | 7 | 8 | class OcclusionAwareGenerator(nn.Module): 9 | """ 10 | Generator that given source image and and keypoints try to transform image according to movement trajectories 11 | induced by keypoints. Generator follows Johnson architecture. 
12 | """ 13 | 14 | def __init__(self, num_channels, num_kp, block_expansion, max_features, num_down_blocks, 15 | num_bottleneck_blocks, estimate_occlusion_map=False, dense_motion_params=None, estimate_jacobian=False): 16 | super(OcclusionAwareGenerator, self).__init__() 17 | 18 | if dense_motion_params is not None: 19 | self.dense_motion_network = DenseMotionNetwork(num_kp=num_kp, num_channels=num_channels, 20 | estimate_occlusion_map=estimate_occlusion_map, 21 | **dense_motion_params) 22 | else: 23 | self.dense_motion_network = None 24 | 25 | self.first = SameBlock2d(num_channels, block_expansion, kernel_size=(7, 7), padding=(3, 3)) 26 | 27 | down_blocks = [] 28 | for i in range(num_down_blocks): 29 | in_features = min(max_features, block_expansion * (2 ** i)) 30 | out_features = min(max_features, block_expansion * (2 ** (i + 1))) 31 | down_blocks.append(DownBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1))) 32 | self.down_blocks = nn.ModuleList(down_blocks) 33 | 34 | up_blocks = [] 35 | for i in range(num_down_blocks): 36 | in_features = min(max_features, block_expansion * (2 ** (num_down_blocks - i))) 37 | out_features = min(max_features, block_expansion * (2 ** (num_down_blocks - i - 1))) 38 | up_blocks.append(UpBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1))) 39 | self.up_blocks = nn.ModuleList(up_blocks) 40 | 41 | self.bottleneck = torch.nn.Sequential() 42 | in_features = min(max_features, block_expansion * (2 ** num_down_blocks)) 43 | for i in range(num_bottleneck_blocks): 44 | self.bottleneck.add_module('r' + str(i), ResBlock2d(in_features, kernel_size=(3, 3), padding=(1, 1))) 45 | 46 | self.final = nn.Conv2d(block_expansion, num_channels, kernel_size=(7, 7), padding=(3, 3)) 47 | self.estimate_occlusion_map = estimate_occlusion_map 48 | self.num_channels = num_channels 49 | 50 | def deform_input(self, inp, deformation): 51 | _, h_old, w_old, _ = deformation.shape 52 | _, _, h, w = inp.shape 53 | if h_old != h or w_old != w: 54 | deformation = deformation.permute(0, 3, 1, 2) 55 | deformation = F.interpolate(deformation, size=(h, w), mode='bilinear') 56 | deformation = deformation.permute(0, 2, 3, 1) 57 | return F.grid_sample(inp, deformation) 58 | # return F.grid_sample(inp, deformation,align_corners = False) 59 | 60 | def forward(self, source_image, kp_driving, kp_source): 61 | # Encoding (downsampling) part 62 | out = self.first(source_image) 63 | for i in range(len(self.down_blocks)): 64 | out = self.down_blocks[i](out) 65 | 66 | # Transforming feature representation according to deformation and occlusion 67 | output_dict = {} 68 | if self.dense_motion_network is not None: 69 | dense_motion = self.dense_motion_network(source_image=source_image, kp_driving=kp_driving, 70 | kp_source=kp_source) 71 | output_dict['mask'] = dense_motion['mask'] 72 | output_dict['sparse_deformed'] = dense_motion['sparse_deformed'] 73 | output_dict['deformation'] = dense_motion['deformation'] 74 | 75 | if 'occlusion_map' in dense_motion: 76 | occlusion_map = dense_motion['occlusion_map'] 77 | output_dict['occlusion_map'] = occlusion_map 78 | else: 79 | occlusion_map = None 80 | deformation = dense_motion['deformation'] 81 | out = self.deform_input(out, deformation) 82 | 83 | if occlusion_map is not None: 84 | if out.shape[2] != occlusion_map.shape[2] or out.shape[3] != occlusion_map.shape[3]: 85 | occlusion_map = F.interpolate(occlusion_map, size=out.shape[2:], mode='bilinear') 86 | out = out * occlusion_map 87 | 88 | output_dict["deformed"] = 
self.deform_input(source_image, deformation) 89 | 90 | # Decoding part 91 | out = self.bottleneck(out) 92 | for i in range(len(self.up_blocks)): 93 | out = self.up_blocks[i](out) 94 | out = self.final(out) 95 | out = F.sigmoid(out) 96 | 97 | output_dict["prediction"] = out 98 | 99 | return output_dict 100 | -------------------------------------------------------------------------------- /modules/keypoint_detector.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch 3 | import torch.nn.functional as F 4 | from modules.util import Hourglass, make_coordinate_grid, AntiAliasInterpolation2d 5 | 6 | 7 | class KPDetector(nn.Module): 8 | """ 9 | Detecting a keypoints. Return keypoint position and jacobian near each keypoint. 10 | """ 11 | 12 | def __init__(self, block_expansion, num_kp, num_channels, max_features, 13 | num_blocks, temperature, estimate_jacobian=False, scale_factor=1, 14 | single_jacobian_map=False, pad=0): 15 | super(KPDetector, self).__init__() 16 | 17 | self.predictor = Hourglass(block_expansion, in_features=num_channels, 18 | max_features=max_features, num_blocks=num_blocks) 19 | 20 | self.kp = nn.Conv2d(in_channels=self.predictor.out_filters, out_channels=num_kp, kernel_size=(7, 7), 21 | padding=pad) 22 | 23 | if estimate_jacobian: 24 | self.num_jacobian_maps = 1 if single_jacobian_map else num_kp 25 | self.jacobian = nn.Conv2d(in_channels=self.predictor.out_filters, 26 | out_channels=4 * self.num_jacobian_maps, kernel_size=(7, 7), padding=pad) 27 | self.jacobian.weight.data.zero_() 28 | self.jacobian.bias.data.copy_(torch.tensor([1, 0, 0, 1] * self.num_jacobian_maps, dtype=torch.float)) 29 | else: 30 | self.jacobian = None 31 | 32 | self.temperature = temperature 33 | self.scale_factor = scale_factor 34 | if self.scale_factor != 1: 35 | self.down = AntiAliasInterpolation2d(num_channels, self.scale_factor) 36 | 37 | def gaussian2kp(self, heatmap): 38 | """ 39 | Extract the mean and from a heatmap 40 | """ 41 | shape = heatmap.shape 42 | heatmap = heatmap.unsqueeze(-1) 43 | grid = make_coordinate_grid(shape[2:], heatmap.type()).unsqueeze_(0).unsqueeze_(0) 44 | value = (heatmap * grid).sum(dim=(2, 3)) 45 | kp = {'value': value} 46 | 47 | return kp 48 | 49 | def forward(self, x): 50 | if self.scale_factor != 1: 51 | x = self.down(x) 52 | 53 | feature_map = self.predictor(x) 54 | prediction = self.kp(feature_map) 55 | 56 | final_shape = prediction.shape 57 | heatmap = prediction.view(final_shape[0], final_shape[1], -1) 58 | heatmap = F.softmax(heatmap / self.temperature, dim=2) 59 | heatmap = heatmap.view(*final_shape) 60 | 61 | out = self.gaussian2kp(heatmap) 62 | 63 | if self.jacobian is not None: 64 | jacobian_map = self.jacobian(feature_map) 65 | 66 | jacobian_map = jacobian_map.reshape(final_shape[0], self.num_jacobian_maps, 4, final_shape[2], 67 | final_shape[3]) 68 | out["jacobian_map"] = jacobian_map 69 | heatmap = heatmap.unsqueeze(2) 70 | 71 | jacobian = heatmap * jacobian_map 72 | jacobian = jacobian.view(final_shape[0], final_shape[1], 4, -1) 73 | jacobian = jacobian.sum(dim=-1) 74 | jacobian = jacobian.view(jacobian.shape[0], jacobian.shape[1], 2, 2) 75 | out['jacobian'] = jacobian 76 | out["pred_fature"] = prediction 77 | return out 78 | -------------------------------------------------------------------------------- /modules/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | def conv3x3(in_planes, 
out_planes, stride=1, groups=1, dilation=1): 5 | """3x3 convolution with padding""" 6 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 7 | padding=dilation, groups=groups, bias=False, dilation=dilation) 8 | 9 | 10 | def conv1x1(in_planes, out_planes, stride=1): 11 | """1x1 convolution""" 12 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 13 | 14 | class BasicBlock(nn.Module): 15 | expansion = 1 16 | 17 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 18 | base_width=64, dilation=1, norm_layer=None): 19 | super(BasicBlock, self).__init__() 20 | if norm_layer is None: 21 | norm_layer = nn.BatchNorm2d 22 | if groups != 1 or base_width != 64: 23 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 24 | if dilation > 1: 25 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 26 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 27 | self.conv1 = conv3x3(inplanes, planes, stride) 28 | self.bn1 = norm_layer(planes) 29 | self.relu = nn.ReLU(inplace=True) 30 | self.conv2 = conv3x3(planes, planes) 31 | self.bn2 = norm_layer(planes) 32 | self.downsample = downsample 33 | self.stride = stride 34 | 35 | def forward(self, x): 36 | identity = x 37 | 38 | out = self.conv1(x) 39 | out = self.bn1(out) 40 | out = self.relu(out) 41 | 42 | out = self.conv2(out) 43 | out = self.bn2(out) 44 | 45 | if self.downsample is not None: 46 | identity = self.downsample(x) 47 | 48 | out += identity 49 | out = self.relu(out) 50 | 51 | return out 52 | 53 | 54 | class Bottleneck(nn.Module): 55 | expansion = 4 56 | 57 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 58 | base_width=64, dilation=1, norm_layer=None): 59 | super(Bottleneck, self).__init__() 60 | if norm_layer is None: 61 | norm_layer = nn.BatchNorm2d 62 | width = int(planes * (base_width / 64.)) * groups 63 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 64 | self.conv1 = conv1x1(inplanes, width) 65 | self.bn1 = norm_layer(width) 66 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 67 | self.bn2 = norm_layer(width) 68 | self.conv3 = conv1x1(width, planes * self.expansion) 69 | self.bn3 = norm_layer(planes * self.expansion) 70 | self.relu = nn.ReLU(inplace=True) 71 | self.downsample = downsample 72 | self.stride = stride 73 | 74 | def forward(self, x): 75 | identity = x 76 | 77 | out = self.conv1(x) 78 | out = self.bn1(out) 79 | out = self.relu(out) 80 | 81 | out = self.conv2(out) 82 | out = self.bn2(out) 83 | out = self.relu(out) 84 | 85 | out = self.conv3(out) 86 | out = self.bn3(out) 87 | 88 | if self.downsample is not None: 89 | identity = self.downsample(x) 90 | 91 | out += identity 92 | out = self.relu(out) 93 | 94 | return out 95 | 96 | class ResNet(nn.Module): 97 | 98 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, 99 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 100 | norm_layer=None,input_channel = 3): 101 | super(ResNet, self).__init__() 102 | if norm_layer is None: 103 | norm_layer = nn.BatchNorm2d 104 | self._norm_layer = norm_layer 105 | 106 | self.inplanes = 64 107 | self.dilation = 1 108 | if replace_stride_with_dilation is None: 109 | # each element in the tuple indicates if we should replace 110 | # the 2x2 stride with a dilated convolution instead 111 | replace_stride_with_dilation = [False, False, False] 112 | if len(replace_stride_with_dilation) != 3: 113 | 
raise ValueError("replace_stride_with_dilation should be None " 114 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 115 | self.groups = groups 116 | self.base_width = width_per_group 117 | self.conv1 = nn.Conv2d(input_channel, self.inplanes, kernel_size=7, stride=2, padding=3, 118 | bias=False) 119 | self.bn1 = norm_layer(self.inplanes) 120 | self.relu = nn.ReLU(inplace=True) 121 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 122 | self.layer1 = self._make_layer(block, 64, layers[0]) 123 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 124 | dilate=replace_stride_with_dilation[0]) 125 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 126 | dilate=replace_stride_with_dilation[1]) 127 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 128 | dilate=replace_stride_with_dilation[2]) 129 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 130 | self.fc = nn.Linear(512 * block.expansion, num_classes) 131 | 132 | for m in self.modules(): 133 | if isinstance(m, nn.Conv2d): 134 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 135 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 136 | nn.init.constant_(m.weight, 1) 137 | nn.init.constant_(m.bias, 0) 138 | 139 | # Zero-initialize the last BN in each residual branch, 140 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 141 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 142 | if zero_init_residual: 143 | for m in self.modules(): 144 | if isinstance(m, Bottleneck): 145 | nn.init.constant_(m.bn3.weight, 0) 146 | elif isinstance(m, BasicBlock): 147 | nn.init.constant_(m.bn2.weight, 0) 148 | 149 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 150 | norm_layer = self._norm_layer 151 | downsample = None 152 | previous_dilation = self.dilation 153 | if dilate: 154 | self.dilation *= stride 155 | stride = 1 156 | if stride != 1 or self.inplanes != planes * block.expansion: 157 | downsample = nn.Sequential( 158 | conv1x1(self.inplanes, planes * block.expansion, stride), 159 | norm_layer(planes * block.expansion), 160 | ) 161 | 162 | layers = [] 163 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 164 | self.base_width, previous_dilation, norm_layer)) 165 | self.inplanes = planes * block.expansion 166 | for _ in range(1, blocks): 167 | layers.append(block(self.inplanes, planes, groups=self.groups, 168 | base_width=self.base_width, dilation=self.dilation, 169 | norm_layer=norm_layer)) 170 | 171 | return nn.Sequential(*layers) 172 | 173 | def forward(self, x): 174 | x = self.conv1(x) 175 | x = self.bn1(x) 176 | x = self.relu(x) 177 | x = self.maxpool(x) 178 | 179 | x = self.layer1(x) 180 | x = self.layer2(x) 181 | x = self.layer3(x) 182 | x = self.layer4(x) 183 | 184 | x = self.avgpool(x) 185 | x = torch.flatten(x, 1) 186 | x = self.fc(x) 187 | 188 | return x 189 | 190 | def _resnet(arch, block, layers, pretrained, progress, **kwargs): 191 | model = ResNet(block, layers, **kwargs) 192 | return model 193 | 194 | def resnet34(pretrained=False, progress=True, **kwargs): 195 | r"""ResNet-34 model from 196 | `"Deep Residual Learning for Image Recognition" `_ 197 | 198 | Args: 199 | pretrained (bool): If True, returns a model pre-trained on ImageNet 200 | progress (bool): If True, displays a progress bar of the download to stderr 201 | """ 202 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, 
203 | **kwargs) -------------------------------------------------------------------------------- /modules/util.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | import torch.nn.functional as F 4 | import torch 5 | 6 | from sync_batchnorm import SynchronizedBatchNorm2d as BatchNorm2d 7 | from sync_batchnorm import SynchronizedBatchNorm3d as BatchNorm3d 8 | from modules.resnet import resnet34 9 | 10 | def gaussian2kp(heatmap): 11 | """ 12 | Extract the mean and from a heatmap 13 | """ 14 | shape = heatmap.shape 15 | heatmap = heatmap.unsqueeze(-1) 16 | grid = make_coordinate_grid(shape[2:], heatmap.type()).unsqueeze_(0).unsqueeze_(0) 17 | value = (heatmap * grid).sum(dim=(2, 3)) 18 | kp = {'value': value} 19 | 20 | return kp 21 | 22 | def kp2gaussian(kp, spatial_size, kp_variance): 23 | """ 24 | Transform a keypoint into gaussian like representation 25 | """ 26 | mean = kp['value'] #bs*numkp*2 27 | 28 | coordinate_grid = make_coordinate_grid(spatial_size, mean.type()) #h*w*2 29 | number_of_leading_dimensions = len(mean.shape) - 1 30 | shape = (1,) * number_of_leading_dimensions + coordinate_grid.shape #1*1*h*w*2 31 | coordinate_grid = coordinate_grid.view(*shape) 32 | repeats = mean.shape[:number_of_leading_dimensions] + (1, 1, 1) 33 | coordinate_grid = coordinate_grid.repeat(*repeats) #bs*numkp*h*w*2 34 | 35 | # Preprocess kp shape 36 | shape = mean.shape[:number_of_leading_dimensions] + (1, 1, 2) 37 | mean = mean.view(*shape) 38 | 39 | mean_sub = (coordinate_grid - mean) 40 | 41 | out = torch.exp(-0.5 * (mean_sub ** 2).sum(-1) / kp_variance) 42 | 43 | return out 44 | 45 | 46 | def make_coordinate_grid(spatial_size, type): 47 | """ 48 | Create a meshgrid [-1,1] x [-1,1] of given spatial_size. 49 | """ 50 | h, w = spatial_size 51 | x = torch.arange(w).type(type) 52 | y = torch.arange(h).type(type) 53 | 54 | x = (2 * (x / (w - 1)) - 1) 55 | y = (2 * (y / (h - 1)) - 1) 56 | 57 | yy = y.view(-1, 1).repeat(1, w) 58 | xx = x.view(1, -1).repeat(h, 1) 59 | 60 | meshed = torch.cat([xx.unsqueeze_(2), yy.unsqueeze_(2)], 2) 61 | 62 | return meshed 63 | 64 | 65 | class ResBlock2d(nn.Module): 66 | """ 67 | Res block, preserve spatial resolution. 68 | """ 69 | 70 | def __init__(self, in_features, kernel_size, padding): 71 | super(ResBlock2d, self).__init__() 72 | self.conv1 = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size, 73 | padding=padding) 74 | self.conv2 = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size, 75 | padding=padding) 76 | self.norm1 = BatchNorm2d(in_features, affine=True) 77 | self.norm2 = BatchNorm2d(in_features, affine=True) 78 | 79 | def forward(self, x): 80 | out = self.norm1(x) 81 | out = F.relu(out,inplace=True) 82 | out = self.conv1(out) 83 | out = self.norm2(out) 84 | out = F.relu(out,inplace=True) 85 | out = self.conv2(out) 86 | out += x 87 | return out 88 | 89 | class ResBlock3d(nn.Module): 90 | """ 91 | Res block, preserve spatial resolution. 
92 | """ 93 | 94 | def __init__(self, in_features, kernel_size, padding): 95 | super(ResBlock3d, self).__init__() 96 | self.conv1 = nn.Conv3d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size, 97 | padding=padding) 98 | self.conv2 = nn.Conv3d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size, 99 | padding=padding) 100 | self.norm1 = BatchNorm3d(in_features, affine=True) 101 | self.norm2 = BatchNorm3d(in_features, affine=True) 102 | 103 | def forward(self, x): 104 | out = self.norm1(x) 105 | out = F.relu(out,inplace=True) 106 | out = self.conv1(out) 107 | out = self.norm2(out) 108 | out = F.relu(out,inplace=True) 109 | out = self.conv2(out) 110 | out += x 111 | return out 112 | 113 | 114 | class UpBlock2d(nn.Module): 115 | """ 116 | Upsampling block for use in decoder. 117 | """ 118 | 119 | def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1): 120 | super(UpBlock2d, self).__init__() 121 | 122 | self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, 123 | padding=padding, groups=groups) 124 | self.norm = BatchNorm2d(out_features, affine=True) 125 | 126 | def forward(self, x): 127 | out = F.interpolate(x, scale_factor=2) 128 | del x 129 | out = self.conv(out) 130 | out = self.norm(out) 131 | out = F.relu(out,inplace=True) 132 | return out 133 | 134 | class UpBlock3d(nn.Module): 135 | """ 136 | Upsampling block for use in decoder. 137 | """ 138 | 139 | def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1): 140 | super(UpBlock3d, self).__init__() 141 | 142 | self.conv = nn.Conv3d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, 143 | padding=padding, groups=groups) 144 | self.norm = BatchNorm3d(out_features, affine=True) 145 | self.res = ResBlock3d(out_features,kernel_size,padding) 146 | self.norm2 = BatchNorm3d(out_features,affine=True) 147 | 148 | def forward(self, x): 149 | out = F.interpolate(x, scale_factor=2) 150 | out = self.conv(out) 151 | out = self.norm(out) 152 | out = F.relu(out,inplace=True) 153 | out = self.res(out) 154 | out = self.norm2(out) 155 | out = F.relu(out,inplace=True) 156 | return out 157 | 158 | class DownBlock2d(nn.Module): 159 | """ 160 | Downsampling block for use in encoder. 161 | """ 162 | 163 | def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1): 164 | super(DownBlock2d, self).__init__() 165 | self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, 166 | padding=padding, groups=groups) 167 | self.norm = BatchNorm2d(out_features, affine=True) 168 | self.pool = nn.AvgPool2d(kernel_size=(2, 2)) 169 | 170 | def forward(self, x): 171 | out = self.conv(x) 172 | del x 173 | out = self.norm(out) 174 | out = F.relu(out,inplace=True) 175 | out = self.pool(out) 176 | return out 177 | 178 | class DownBlock3d(nn.Module): 179 | """ 180 | Downsampling block for use in encoder. 
181 | """ 182 | 183 | def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1): 184 | super(DownBlock3d, self).__init__() 185 | 186 | self.res = ResBlock3d(in_features=in_features,kernel_size=kernel_size,padding=padding) 187 | self.norm_res = BatchNorm3d(in_features,affine=True) 188 | self.conv = nn.Conv3d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, 189 | padding=padding, groups=groups) 190 | 191 | self.norm = BatchNorm3d(out_features, affine=True) 192 | self.pool = nn.AvgPool3d(kernel_size=(2, 2, 2)) 193 | 194 | def forward(self, x): 195 | out = self.res(x) 196 | out = self.norm_res(out) 197 | out = F.relu(out,inplace=True) 198 | out = self.conv(out) 199 | out = self.norm(out) 200 | out = F.relu(out,inplace=True) 201 | out = self.pool(out) 202 | return out 203 | 204 | class SameBlock2d(nn.Module): 205 | """ 206 | Simple block, preserve spatial resolution. 207 | """ 208 | 209 | def __init__(self, in_features, out_features, groups=1, kernel_size=3, padding=1): 210 | super(SameBlock2d, self).__init__() 211 | self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, 212 | kernel_size=kernel_size, padding=padding, groups=groups) 213 | self.norm = BatchNorm2d(out_features, affine=True) 214 | 215 | def forward(self, x): 216 | out = self.conv(x) 217 | out = self.norm(out) 218 | out = F.relu(out,inplace=True) 219 | return out 220 | 221 | 222 | class Encoder(nn.Module): 223 | """ 224 | Hourglass Encoder 225 | """ 226 | 227 | def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256): 228 | super(Encoder, self).__init__() 229 | 230 | down_blocks = [] 231 | for i in range(num_blocks): 232 | down_blocks.append(DownBlock2d(in_features if i == 0 else min(max_features, block_expansion * (2 ** i)), 233 | min(max_features, block_expansion * (2 ** (i + 1))), 234 | kernel_size=3, padding=1)) 235 | self.down_blocks = nn.ModuleList(down_blocks) 236 | 237 | def forward(self, x): 238 | outs = [x] 239 | for down_block in self.down_blocks: 240 | outs.append(down_block(outs[-1])) 241 | return outs 242 | 243 | 244 | class Encoder3D(nn.Module): 245 | """ 246 | Hourglass Encoder 247 | """ 248 | 249 | def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256): 250 | super(Encoder3D, self).__init__() 251 | 252 | down_blocks = [] 253 | for i in range(num_blocks): 254 | down_blocks.append(DownBlock3d(in_features if i == 0 else min(max_features, block_expansion * (2 ** i)), 255 | min(max_features, block_expansion * (2 ** (i + 1))), 256 | kernel_size=3, padding=1)) 257 | self.down_blocks = nn.ModuleList(down_blocks) 258 | 259 | def forward(self, x): 260 | outs = [x] 261 | for down_block in self.down_blocks: 262 | outs.append(down_block(outs[-1])) 263 | return outs 264 | 265 | class Decoder(nn.Module): 266 | """ 267 | Hourglass Decoder 268 | """ 269 | 270 | def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256): 271 | super(Decoder, self).__init__() 272 | 273 | up_blocks = [] 274 | 275 | for i in range(num_blocks)[::-1]: 276 | in_filters = (1 if i == num_blocks - 1 else 2) * min(max_features, block_expansion * (2 ** (i + 1))) 277 | out_filters = min(max_features, block_expansion * (2 ** i)) 278 | up_blocks.append(UpBlock2d(in_filters, out_filters, kernel_size=3, padding=1)) 279 | 280 | self.up_blocks = nn.ModuleList(up_blocks) 281 | self.out_filters = block_expansion + in_features 282 | 283 | def forward(self, x): 284 | out = x.pop() 285 | for up_block in self.up_blocks: 286 | out = 
up_block(out) 287 | skip = x.pop() 288 | out = torch.cat([out, skip], dim=1) 289 | return out 290 | 291 | 292 | class Decoder3D(nn.Module): 293 | """ 294 | Hourglass Decoder 295 | """ 296 | 297 | def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256): 298 | super(Decoder3D, self).__init__() 299 | 300 | up_blocks = [] 301 | res_blocks = [] 302 | 303 | for i in range(num_blocks)[::-1]: 304 | in_filters = (1 if i == num_blocks - 1 else 2) * min(max_features, block_expansion * (2 ** (i + 1))) 305 | out_filters = min(max_features, block_expansion * (2 ** i)) 306 | up_blocks.append(UpBlock3d(in_filters, out_filters, kernel_size=3, padding=1)) 307 | if i>0: 308 | res_blocks.append(nn.Sequential(ResBlock3d(out_filters,kernel_size=3,padding=1),BatchNorm3d(out_filters), nn.ReLU(inplace=True))) 309 | else: 310 | res_blocks.append(nn.Sequential(ResBlock3d(in_features,kernel_size=3,padding=1),BatchNorm3d(in_features), nn.ReLU(inplace=True))) 311 | self.res_blocks = nn.ModuleList(res_blocks) 312 | self.up_blocks = nn.ModuleList(up_blocks) 313 | self.out_filters = block_expansion + in_features 314 | 315 | def forward(self, x): 316 | out = x.pop() 317 | for up_block,res_bl in zip(self.up_blocks,self.res_blocks): 318 | out = up_block(out) 319 | skip = x.pop() 320 | out = torch.cat([out, res_bl(skip)], dim=1) 321 | return out 322 | 323 | class Hourglass(nn.Module): 324 | """ 325 | Hourglass architecture. 326 | """ 327 | 328 | def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256): 329 | super(Hourglass, self).__init__() 330 | self.encoder = Encoder(block_expansion, in_features, num_blocks, max_features) 331 | self.decoder = Decoder(block_expansion, in_features, num_blocks, max_features) 332 | self.out_filters = self.decoder.out_filters 333 | 334 | def forward(self, x): 335 | return self.decoder(self.encoder(x)) 336 | 337 | class Hourglass3D(nn.Module): 338 | """ 339 | Hourglass architecture. 340 | """ 341 | 342 | def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256): 343 | super(Hourglass3D, self).__init__() 344 | self.encoder = Encoder3D(block_expansion, in_features, num_blocks, max_features) 345 | self.decoder = Decoder3D(block_expansion, in_features, num_blocks, max_features) 346 | self.out_filters = self.decoder.out_filters 347 | 348 | def forward(self, x): 349 | return self.decoder(self.encoder(x)) 350 | 351 | 352 | class AntiAliasInterpolation2d(nn.Module): 353 | """ 354 | Band-limited downsampling, for better preservation of the input signal. 355 | """ 356 | def __init__(self, channels, scale): 357 | super(AntiAliasInterpolation2d, self).__init__() 358 | sigma = (1 / scale - 1) / 2 359 | kernel_size = 2 * round(sigma * 4) + 1 360 | self.ka = kernel_size // 2 361 | self.kb = self.ka - 1 if kernel_size % 2 == 0 else self.ka 362 | 363 | 364 | kernel_size = [kernel_size, kernel_size] 365 | sigma = [sigma, sigma] 366 | # The gaussian kernel is the product of the 367 | # gaussian function of each dimension. 368 | kernel = 1 369 | meshgrids = torch.meshgrid( 370 | [ 371 | torch.arange(size, dtype=torch.float32) 372 | for size in kernel_size 373 | ] 374 | ) 375 | for size, std, mgrid in zip(kernel_size, sigma, meshgrids): 376 | mean = (size - 1) / 2 377 | kernel *= torch.exp(-(mgrid - mean) ** 2 / (2 * std ** 2)) 378 | 379 | # Make sure sum of values in gaussian kernel equals 1. 
380 | kernel = kernel / torch.sum(kernel) 381 | # Reshape to depthwise convolutional weight 382 | kernel = kernel.view(1, 1, *kernel.size()) 383 | kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1)) 384 | 385 | self.register_buffer('weight', kernel) 386 | self.groups = channels 387 | self.scale = scale 388 | 389 | def forward(self, input): 390 | if self.scale == 1.0: 391 | return input 392 | 393 | out = F.pad(input, (self.ka, self.kb, self.ka, self.kb)) 394 | out = F.conv2d(out, weight=self.weight, groups=self.groups) 395 | out = F.interpolate(out, scale_factor=(self.scale, self.scale)) 396 | 397 | return out 398 | 399 | 400 | 401 | class MyResNet34(nn.Module): 402 | def __init__(self,embedding_dim,input_channel = 3): 403 | super(MyResNet34, self).__init__() 404 | self.resnet = resnet34(norm_layer = BatchNorm2d,num_classes=embedding_dim,input_channel = input_channel) 405 | def forward(self, x): 406 | return self.resnet(x) 407 | 408 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | scikit-image 2 | python_speech_features 3 | pyworld 4 | pyyaml 5 | pytorch-lightning 6 | imageio 7 | scipy 8 | pyworld 9 | opencv-python -------------------------------------------------------------------------------- /sync_batchnorm/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : __init__.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d 12 | from .replicate import DataParallelWithCallback, patch_replication_callback 13 | -------------------------------------------------------------------------------- /sync_batchnorm/batchnorm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : batchnorm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 
10 | 11 | import collections 12 | 13 | import torch 14 | import torch.nn.functional as F 15 | 16 | from torch.nn.modules.batchnorm import _BatchNorm 17 | from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast 18 | 19 | from .comm import SyncMaster 20 | 21 | __all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d'] 22 | 23 | 24 | def _sum_ft(tensor): 25 | """sum over the first and last dimention""" 26 | return tensor.sum(dim=0).sum(dim=-1) 27 | 28 | 29 | def _unsqueeze_ft(tensor): 30 | """add new dementions at the front and the tail""" 31 | return tensor.unsqueeze(0).unsqueeze(-1) 32 | 33 | 34 | _ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size']) 35 | _MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std']) 36 | 37 | 38 | class _SynchronizedBatchNorm(_BatchNorm): 39 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True): 40 | super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine) 41 | 42 | self._sync_master = SyncMaster(self._data_parallel_master) 43 | 44 | self._is_parallel = False 45 | self._parallel_id = None 46 | self._slave_pipe = None 47 | 48 | def forward(self, input): 49 | # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation. 50 | if not (self._is_parallel and self.training): 51 | return F.batch_norm( 52 | input, self.running_mean, self.running_var, self.weight, self.bias, 53 | self.training, self.momentum, self.eps) 54 | 55 | # Resize the input to (B, C, -1). 56 | input_shape = input.size() 57 | input = input.view(input.size(0), self.num_features, -1) 58 | 59 | # Compute the sum and square-sum. 60 | sum_size = input.size(0) * input.size(2) 61 | input_sum = _sum_ft(input) 62 | input_ssum = _sum_ft(input ** 2) 63 | 64 | # Reduce-and-broadcast the statistics. 65 | if self._parallel_id == 0: 66 | mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size)) 67 | else: 68 | mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size)) 69 | 70 | # Compute the output. 71 | if self.affine: 72 | # MJY:: Fuse the multiplication for speed. 73 | output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias) 74 | else: 75 | output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std) 76 | 77 | # Reshape it. 78 | return output.view(input_shape) 79 | 80 | def __data_parallel_replicate__(self, ctx, copy_id): 81 | self._is_parallel = True 82 | self._parallel_id = copy_id 83 | 84 | # parallel_id == 0 means master device. 85 | if self._parallel_id == 0: 86 | ctx.sync_master = self._sync_master 87 | else: 88 | self._slave_pipe = ctx.sync_master.register_slave(copy_id) 89 | 90 | def _data_parallel_master(self, intermediates): 91 | """Reduce the sum and square-sum, compute the statistics, and broadcast it.""" 92 | 93 | # Always using same "device order" makes the ReduceAdd operation faster. 
94 | # Thanks to:: Tete Xiao (http://tetexiao.com/) 95 | intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device()) 96 | 97 | to_reduce = [i[1][:2] for i in intermediates] 98 | to_reduce = [j for i in to_reduce for j in i] # flatten 99 | target_gpus = [i[1].sum.get_device() for i in intermediates] 100 | 101 | sum_size = sum([i[1].sum_size for i in intermediates]) 102 | sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce) 103 | mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size) 104 | 105 | broadcasted = Broadcast.apply(target_gpus, mean, inv_std) 106 | 107 | outputs = [] 108 | for i, rec in enumerate(intermediates): 109 | outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2]))) 110 | 111 | return outputs 112 | 113 | def _compute_mean_std(self, sum_, ssum, size): 114 | """Compute the mean and standard-deviation with sum and square-sum. This method 115 | also maintains the moving average on the master device.""" 116 | assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.' 117 | mean = sum_ / size 118 | sumvar = ssum - sum_ * mean 119 | unbias_var = sumvar / (size - 1) 120 | bias_var = sumvar / size 121 | 122 | self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data 123 | self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data 124 | 125 | return mean, bias_var.clamp(self.eps) ** -0.5 126 | 127 | 128 | class SynchronizedBatchNorm1d(_SynchronizedBatchNorm): 129 | r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a 130 | mini-batch. 131 | 132 | .. math:: 133 | 134 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta 135 | 136 | This module differs from the built-in PyTorch BatchNorm1d as the mean and 137 | standard-deviation are reduced across all devices during training. 138 | 139 | For example, when one uses `nn.DataParallel` to wrap the network during 140 | training, PyTorch's implementation normalize the tensor on each device using 141 | the statistics only on that device, which accelerated the computation and 142 | is also easy to implement, but the statistics might be inaccurate. 143 | Instead, in this synchronized version, the statistics will be computed 144 | over all training samples distributed on multiple devices. 145 | 146 | Note that, for one-GPU or CPU-only case, this module behaves exactly same 147 | as the built-in PyTorch implementation. 148 | 149 | The mean and standard-deviation are calculated per-dimension over 150 | the mini-batches and gamma and beta are learnable parameter vectors 151 | of size C (where C is the input size). 152 | 153 | During training, this layer keeps a running estimate of its computed mean 154 | and variance. The running sum is kept with a default momentum of 0.1. 155 | 156 | During evaluation, this running mean/variance is used for normalization. 157 | 158 | Because the BatchNorm is done over the `C` dimension, computing statistics 159 | on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm 160 | 161 | Args: 162 | num_features: num_features from an expected input of size 163 | `batch_size x num_features [x width]` 164 | eps: a value added to the denominator for numerical stability. 165 | Default: 1e-5 166 | momentum: the value used for the running_mean and running_var 167 | computation. Default: 0.1 168 | affine: a boolean value that when set to ``True``, gives the layer learnable 169 | affine parameters. 
Default: ``True`` 170 | 171 | Shape: 172 | - Input: :math:`(N, C)` or :math:`(N, C, L)` 173 | - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input) 174 | 175 | Examples: 176 | >>> # With Learnable Parameters 177 | >>> m = SynchronizedBatchNorm1d(100) 178 | >>> # Without Learnable Parameters 179 | >>> m = SynchronizedBatchNorm1d(100, affine=False) 180 | >>> input = torch.autograd.Variable(torch.randn(20, 100)) 181 | >>> output = m(input) 182 | """ 183 | 184 | def _check_input_dim(self, input): 185 | if input.dim() != 2 and input.dim() != 3: 186 | raise ValueError('expected 2D or 3D input (got {}D input)' 187 | .format(input.dim())) 188 | super(SynchronizedBatchNorm1d, self)._check_input_dim(input) 189 | 190 | 191 | class SynchronizedBatchNorm2d(_SynchronizedBatchNorm): 192 | r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch 193 | of 3d inputs 194 | 195 | .. math:: 196 | 197 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta 198 | 199 | This module differs from the built-in PyTorch BatchNorm2d as the mean and 200 | standard-deviation are reduced across all devices during training. 201 | 202 | For example, when one uses `nn.DataParallel` to wrap the network during 203 | training, PyTorch's implementation normalize the tensor on each device using 204 | the statistics only on that device, which accelerated the computation and 205 | is also easy to implement, but the statistics might be inaccurate. 206 | Instead, in this synchronized version, the statistics will be computed 207 | over all training samples distributed on multiple devices. 208 | 209 | Note that, for one-GPU or CPU-only case, this module behaves exactly same 210 | as the built-in PyTorch implementation. 211 | 212 | The mean and standard-deviation are calculated per-dimension over 213 | the mini-batches and gamma and beta are learnable parameter vectors 214 | of size C (where C is the input size). 215 | 216 | During training, this layer keeps a running estimate of its computed mean 217 | and variance. The running sum is kept with a default momentum of 0.1. 218 | 219 | During evaluation, this running mean/variance is used for normalization. 220 | 221 | Because the BatchNorm is done over the `C` dimension, computing statistics 222 | on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm 223 | 224 | Args: 225 | num_features: num_features from an expected input of 226 | size batch_size x num_features x height x width 227 | eps: a value added to the denominator for numerical stability. 228 | Default: 1e-5 229 | momentum: the value used for the running_mean and running_var 230 | computation. Default: 0.1 231 | affine: a boolean value that when set to ``True``, gives the layer learnable 232 | affine parameters. 
Default: ``True`` 233 | 234 | Shape: 235 | - Input: :math:`(N, C, H, W)` 236 | - Output: :math:`(N, C, H, W)` (same shape as input) 237 | 238 | Examples: 239 | >>> # With Learnable Parameters 240 | >>> m = SynchronizedBatchNorm2d(100) 241 | >>> # Without Learnable Parameters 242 | >>> m = SynchronizedBatchNorm2d(100, affine=False) 243 | >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45)) 244 | >>> output = m(input) 245 | """ 246 | 247 | def _check_input_dim(self, input): 248 | if input.dim() != 4: 249 | raise ValueError('expected 4D input (got {}D input)' 250 | .format(input.dim())) 251 | super(SynchronizedBatchNorm2d, self)._check_input_dim(input) 252 | 253 | 254 | class SynchronizedBatchNorm3d(_SynchronizedBatchNorm): 255 | r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch 256 | of 4d inputs 257 | 258 | .. math:: 259 | 260 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta 261 | 262 | This module differs from the built-in PyTorch BatchNorm3d as the mean and 263 | standard-deviation are reduced across all devices during training. 264 | 265 | For example, when one uses `nn.DataParallel` to wrap the network during 266 | training, PyTorch's implementation normalize the tensor on each device using 267 | the statistics only on that device, which accelerated the computation and 268 | is also easy to implement, but the statistics might be inaccurate. 269 | Instead, in this synchronized version, the statistics will be computed 270 | over all training samples distributed on multiple devices. 271 | 272 | Note that, for one-GPU or CPU-only case, this module behaves exactly same 273 | as the built-in PyTorch implementation. 274 | 275 | The mean and standard-deviation are calculated per-dimension over 276 | the mini-batches and gamma and beta are learnable parameter vectors 277 | of size C (where C is the input size). 278 | 279 | During training, this layer keeps a running estimate of its computed mean 280 | and variance. The running sum is kept with a default momentum of 0.1. 281 | 282 | During evaluation, this running mean/variance is used for normalization. 283 | 284 | Because the BatchNorm is done over the `C` dimension, computing statistics 285 | on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm 286 | or Spatio-temporal BatchNorm 287 | 288 | Args: 289 | num_features: num_features from an expected input of 290 | size batch_size x num_features x depth x height x width 291 | eps: a value added to the denominator for numerical stability. 292 | Default: 1e-5 293 | momentum: the value used for the running_mean and running_var 294 | computation. Default: 0.1 295 | affine: a boolean value that when set to ``True``, gives the layer learnable 296 | affine parameters. 
Default: ``True`` 297 | 298 | Shape: 299 | - Input: :math:`(N, C, D, H, W)` 300 | - Output: :math:`(N, C, D, H, W)` (same shape as input) 301 | 302 | Examples: 303 | >>> # With Learnable Parameters 304 | >>> m = SynchronizedBatchNorm3d(100) 305 | >>> # Without Learnable Parameters 306 | >>> m = SynchronizedBatchNorm3d(100, affine=False) 307 | >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10)) 308 | >>> output = m(input) 309 | """ 310 | 311 | def _check_input_dim(self, input): 312 | if input.dim() != 5: 313 | raise ValueError('expected 5D input (got {}D input)' 314 | .format(input.dim())) 315 | super(SynchronizedBatchNorm3d, self)._check_input_dim(input) 316 | -------------------------------------------------------------------------------- /sync_batchnorm/comm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : comm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import queue 12 | import collections 13 | import threading 14 | 15 | __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster'] 16 | 17 | 18 | class FutureResult(object): 19 | """A thread-safe future implementation. Used only as one-to-one pipe.""" 20 | 21 | def __init__(self): 22 | self._result = None 23 | self._lock = threading.Lock() 24 | self._cond = threading.Condition(self._lock) 25 | 26 | def put(self, result): 27 | with self._lock: 28 | assert self._result is None, 'Previous result has\'t been fetched.' 29 | self._result = result 30 | self._cond.notify() 31 | 32 | def get(self): 33 | with self._lock: 34 | if self._result is None: 35 | self._cond.wait() 36 | 37 | res = self._result 38 | self._result = None 39 | return res 40 | 41 | 42 | _MasterRegistry = collections.namedtuple('MasterRegistry', ['result']) 43 | _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result']) 44 | 45 | 46 | class SlavePipe(_SlavePipeBase): 47 | """Pipe for master-slave communication.""" 48 | 49 | def run_slave(self, msg): 50 | self.queue.put((self.identifier, msg)) 51 | ret = self.result.get() 52 | self.queue.put(True) 53 | return ret 54 | 55 | 56 | class SyncMaster(object): 57 | """An abstract `SyncMaster` object. 58 | 59 | - During the replication, as the data parallel will trigger an callback of each module, all slave devices should 60 | call `register(id)` and obtain an `SlavePipe` to communicate with the master. 61 | - During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected, 62 | and passed to a registered callback. 63 | - After receiving the messages, the master device should gather the information and determine to message passed 64 | back to each slave devices. 65 | """ 66 | 67 | def __init__(self, master_callback): 68 | """ 69 | 70 | Args: 71 | master_callback: a callback to be invoked after having collected messages from slave devices. 
72 | """ 73 | self._master_callback = master_callback 74 | self._queue = queue.Queue() 75 | self._registry = collections.OrderedDict() 76 | self._activated = False 77 | 78 | def __getstate__(self): 79 | return {'master_callback': self._master_callback} 80 | 81 | def __setstate__(self, state): 82 | self.__init__(state['master_callback']) 83 | 84 | def register_slave(self, identifier): 85 | """ 86 | Register an slave device. 87 | 88 | Args: 89 | identifier: an identifier, usually is the device id. 90 | 91 | Returns: a `SlavePipe` object which can be used to communicate with the master device. 92 | 93 | """ 94 | if self._activated: 95 | assert self._queue.empty(), 'Queue is not clean before next initialization.' 96 | self._activated = False 97 | self._registry.clear() 98 | future = FutureResult() 99 | self._registry[identifier] = _MasterRegistry(future) 100 | return SlavePipe(identifier, self._queue, future) 101 | 102 | def run_master(self, master_msg): 103 | """ 104 | Main entry for the master device in each forward pass. 105 | The messages were first collected from each devices (including the master device), and then 106 | an callback will be invoked to compute the message to be sent back to each devices 107 | (including the master device). 108 | 109 | Args: 110 | master_msg: the message that the master want to send to itself. This will be placed as the first 111 | message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example. 112 | 113 | Returns: the message to be sent back to the master device. 114 | 115 | """ 116 | self._activated = True 117 | 118 | intermediates = [(0, master_msg)] 119 | for i in range(self.nr_slaves): 120 | intermediates.append(self._queue.get()) 121 | 122 | results = self._master_callback(intermediates) 123 | assert results[0][0] == 0, 'The first result should belongs to the master.' 124 | 125 | for i, res in results: 126 | if i == 0: 127 | continue 128 | self._registry[i].result.put(res) 129 | 130 | for i in range(self.nr_slaves): 131 | assert self._queue.get() is True 132 | 133 | return results[0][1] 134 | 135 | @property 136 | def nr_slaves(self): 137 | return len(self._registry) 138 | -------------------------------------------------------------------------------- /sync_batchnorm/replicate.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : replicate.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import functools 12 | 13 | from torch.nn.parallel.data_parallel import DataParallel 14 | 15 | __all__ = [ 16 | 'CallbackContext', 17 | 'execute_replication_callbacks', 18 | 'DataParallelWithCallback', 19 | 'patch_replication_callback' 20 | ] 21 | 22 | 23 | class CallbackContext(object): 24 | pass 25 | 26 | 27 | def execute_replication_callbacks(modules): 28 | """ 29 | Execute an replication callback `__data_parallel_replicate__` on each module created by original replication. 30 | 31 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 32 | 33 | Note that, as all modules are isomorphism, we assign each sub-module with a context 34 | (shared among multiple copies of this module on different devices). 35 | Through this context, different copies can share some information. 
36 | 37 | We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback 38 | of any slave copies. 39 | """ 40 | master_copy = modules[0] 41 | nr_modules = len(list(master_copy.modules())) 42 | ctxs = [CallbackContext() for _ in range(nr_modules)] 43 | 44 | for i, module in enumerate(modules): 45 | for j, m in enumerate(module.modules()): 46 | if hasattr(m, '__data_parallel_replicate__'): 47 | m.__data_parallel_replicate__(ctxs[j], i) 48 | 49 | 50 | class DataParallelWithCallback(DataParallel): 51 | """ 52 | Data Parallel with a replication callback. 53 | 54 | An replication callback `__data_parallel_replicate__` of each module will be invoked after being created by 55 | original `replicate` function. 56 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 57 | 58 | Examples: 59 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 60 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 61 | # sync_bn.__data_parallel_replicate__ will be invoked. 62 | """ 63 | 64 | def replicate(self, module, device_ids): 65 | modules = super(DataParallelWithCallback, self).replicate(module, device_ids) 66 | execute_replication_callbacks(modules) 67 | return modules 68 | 69 | 70 | def patch_replication_callback(data_parallel): 71 | """ 72 | Monkey-patch an existing `DataParallel` object. Add the replication callback. 73 | Useful when you have customized `DataParallel` implementation. 74 | 75 | Examples: 76 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 77 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) 78 | > patch_replication_callback(sync_bn) 79 | # this is equivalent to 80 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 81 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 82 | """ 83 | 84 | assert isinstance(data_parallel, DataParallel) 85 | 86 | old_replicate = data_parallel.replicate 87 | 88 | @functools.wraps(old_replicate) 89 | def new_replicate(module, device_ids): 90 | modules = old_replicate(module, device_ids) 91 | execute_replication_callbacks(modules) 92 | return modules 93 | 94 | data_parallel.replicate = new_replicate 95 | -------------------------------------------------------------------------------- /sync_batchnorm/unittest.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : unittest.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import unittest 12 | 13 | import numpy as np 14 | from torch.autograd import Variable 15 | 16 | 17 | def as_numpy(v): 18 | if isinstance(v, Variable): 19 | v = v.data 20 | return v.cpu().numpy() 21 | 22 | 23 | class TorchTestCase(unittest.TestCase): 24 | def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3): 25 | npa, npb = as_numpy(a), as_numpy(b) 26 | self.assertTrue( 27 | np.allclose(npa, npb, atol=atol), 28 | 'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max()) 29 | ) 30 | --------------------------------------------------------------------------------
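`AntiAliasInterpolation2d` in `modules/util.py` performs band-limited downsampling: it derives a depthwise Gaussian from the scale alone (`sigma = (1 / scale - 1) / 2`, `kernel_size = 2 * round(4 * sigma) + 1`), filters, and only then resamples with `F.interpolate`. A minimal sketch of the resulting shapes, assuming the repository root is on `PYTHONPATH` (the 256x256 input size is just an example):

```
import torch
from modules.util import AntiAliasInterpolation2d

# scale = 0.25 gives sigma = (1 / 0.25 - 1) / 2 = 1.5 and a 13-tap kernel
# (2 * round(4 * 1.5) + 1 = 13), applied depthwise before the final resize.
down = AntiAliasInterpolation2d(channels=3, scale=0.25)

img = torch.randn(1, 3, 256, 256)   # arbitrary example input
small = down(img)
print(small.shape)                  # torch.Size([1, 3, 64, 64])
```

With `scale == 1.0` the module returns its input unchanged, so the same layer can sit in a full-resolution branch without special-casing.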
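The `SyncMaster`/`SlavePipe` pair in `sync_batchnorm/comm.py` implements the collect-and-broadcast handshake its docstring describes: the master gathers one message per device, hands the list to `master_callback`, and returns each device its slice of the result. `_SynchronizedBatchNorm` drives this internally, but with no slaves registered the round trip degenerates to an echo, which makes the control flow easy to see in a toy sketch (illustrative only, not how the repo calls it):

```
from sync_batchnorm.comm import SyncMaster

# Identity callback: return the collected (copy_id, message) pairs unchanged.
master = SyncMaster(lambda intermediates: intermediates)

# No slaves are registered, so run_master only round-trips its own message
# through the callback and returns the entry addressed to copy_id 0.
print(master.run_master('stats-from-master'))   # prints 'stats-from-master'
```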
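Tying the vendored package together: as the docstrings in `sync_batchnorm/batchnorm.py` and `sync_batchnorm/replicate.py` note, a network containing `SynchronizedBatchNorm*` layers should be wrapped in `DataParallelWithCallback` (or patched via `patch_replication_callback`) so that every replica registers with the shared `SyncMaster`. A hedged sketch with a purely illustrative model (the layer sizes and device ids are assumptions, not from this repo):

```
import torch
import torch.nn as nn
from sync_batchnorm import SynchronizedBatchNorm2d, DataParallelWithCallback

# Purely illustrative module; any network built with SynchronizedBatchNorm
# layers is wrapped the same way.
model = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    SynchronizedBatchNorm2d(16),   # statistics reduced across replicas in training
    nn.ReLU(inplace=True),
)

x = torch.randn(8, 3, 64, 64)
if torch.cuda.device_count() > 1:
    # The replication callback wires every SynchronizedBatchNorm copy to the
    # shared SyncMaster, so mean/var are computed over the whole batch.
    model = DataParallelWithCallback(model.cuda(), device_ids=[0, 1])
    x = x.cuda()
# On CPU or a single GPU each layer falls back to F.batch_norm, so the same
# code runs unmodified.

out = model(x)
print(out.shape)   # torch.Size([8, 16, 64, 64])
```

In the multi-GPU branch, the replication callback is what invokes each copy's `__data_parallel_replicate__` hook, which is exactly what `execute_replication_callbacks` in `sync_batchnorm/replicate.py` automates.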