├── .gitattributes
├── 10_fake.png
├── 146_fake.png
├── Diagram.png
├── README.md
├── attribute_transfer_model.py
├── data_loader.py
├── expression_synthesis.py
├── sequence_landmark.py
├── shape_predictor_68_face_landmarks.dat
├── solver.py
└── tensorboard_logger.py

/.gitattributes:
--------------------------------------------------------------------------------
*.mp4 filter=lfs diff=lfs merge=lfs -text
*.dat filter=lfs diff=lfs merge=lfs -text
--------------------------------------------------------------------------------
/10_fake.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CreativePapers/Learn-to-Synthesize-and-Synthesize-to-Learn/000f1bd2eda604e1db58d6b91898bdbbf3a2f996/10_fake.png
--------------------------------------------------------------------------------
/146_fake.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CreativePapers/Learn-to-Synthesize-and-Synthesize-to-Learn/000f1bd2eda604e1db58d6b91898bdbbf3a2f996/146_fake.png
--------------------------------------------------------------------------------
/Diagram.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CreativePapers/Learn-to-Synthesize-and-Synthesize-to-Learn/000f1bd2eda604e1db58d6b91898bdbbf3a2f996/Diagram.png
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Learn to Synthesize and Synthesize to Learn
This repository contains the training and testing code for "Learn to Synthesize and Synthesize to Learn".

![](https://github.com/CreativePapers/ECCV2018/blob/master/146_fake.png)

### Dependencies
- [Python2.7](https://www.anaconda.com/download/#linux)
- [PyTorch](http://pytorch.org/)
- [torchvision](http://pytorch.org/docs/master/torchvision)
- [OpenCV](https://opencv.org/)
- [Dlib](http://dlib.net/)

### Datasets
- [The Binghamton University 3D Facial Expression Database (BU-3DFE)](http://www.cs.binghamton.edu/~lijun/Research/3DFE/3DFE_Analysis.html)
- [The Radboud Faces Database (RaFD)](http://www.socsci.ru.nl:8180/RaFD2/RaFD?p=main)
- [The MUG dataset](https://mug.ee.auth.gr/fed/)
- [Oulu-CASIA VIS](http://www.cse.oulu.fi/CMV/Downloads/Oulu-CASIA)

## Example Usage for the expression synthesis model

### Clone the repository
```
$ git clone https://github.com/CreativePapers/ECCV2018.git
$ cd ECCV2018
```
### Train
```
python expression_synthesis.py --mode='train'
```
### Test
```
python expression_synthesis.py --mode='test'
```
### Attribute-Guided Face Synthesis Model

![](https://github.com/CreativePapers/ECCV2018/blob/master/Diagram.png)

### Sample Results

![](https://github.com/CreativePapers/ECCV2018/blob/master/10_fake.png)

## TO DO
Add the pose normalization model
--------------------------------------------------------------------------------
/attribute_transfer_model.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


def num_flat_features(x):
    size = x.size()[1:]
    num_features = 1
    for s in size:
        num_features *= s
    return num_features


class _Residual_Block(nn.Module):
    """Residual block."""
    def __init__(self, dim_in, dim_out):
        super(_Residual_Block, self).__init__()
        self.main = nn.Sequential(
            nn.Conv2d(dim_in, dim_out, kernel_size=3, stride=1, padding=1, bias=False),
            nn.InstanceNorm2d(dim_out, affine=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(dim_out, dim_out, kernel_size=3, stride=1, padding=1, bias=False),
            nn.InstanceNorm2d(dim_out, affine=True))

    def forward(self, x):
        return x + self.main(x)


class Encoder(nn.Module):
    """
    Generator (convolutional encoder) that encodes the 6-channel input
    (RGB image + landmark heatmap) into a latent feature map.
    """
    def __init__(self, conv_dim=64, repeat_num=6):
        super(Encoder, self).__init__()

        net = []
        net.append(nn.Conv2d(6, conv_dim, kernel_size=7, stride=1, padding=3, bias=False))
        net.append(nn.InstanceNorm2d(conv_dim, affine=True))
        net.append(nn.ReLU(inplace=True))

        # Down-sampling
        channel_dim = conv_dim
        for i in range(4):
            net.append(nn.Conv2d(channel_dim, channel_dim * 2, kernel_size=4, stride=2, padding=1, bias=False))
            net.append(nn.InstanceNorm2d(channel_dim * 2, affine=True))
            net.append(nn.ReLU(inplace=True))
            channel_dim = channel_dim * 2

        # Residual blocks
        for i in range(repeat_num):
            net.append(_Residual_Block(dim_in=channel_dim, dim_out=channel_dim))

        self.main = nn.Sequential(*net)

    def forward(self, x):
        return self.main(x)


class Decoder(nn.Module):
    """
    Generator (convolutional decoder) that takes the attribute vector and the
    latent features and synthesizes new images conditioned on the attributes
    of interest.
    """
    def __init__(self, conv_dim=64):
        super(Decoder, self).__init__()
        # seven attribute (expression) categories
        channel_dim = 16 * conv_dim + 7
        net = []

        # Up-sampling with sub-pixel convolution (PixelShuffle)
        for i in range(4):
            up_scale = 2
            net.append(nn.Conv2d(channel_dim, channel_dim // 2 * up_scale ** 2, kernel_size=3, padding=1))
            net.append(nn.PixelShuffle(up_scale))
            net.append(nn.InstanceNorm2d(channel_dim // 2, affine=True))
            net.append(nn.ReLU(inplace=True))
            channel_dim = channel_dim // 2

        self.main = nn.Sequential(*net)
        self.conv1 = nn.Conv2d(channel_dim, 3, kernel_size=7, stride=1, padding=3, bias=False)
        self.conv2 = nn.Conv2d(channel_dim, 3, kernel_size=7, stride=1, padding=3, bias=False)
        self.m = nn.Tanh()

    def forward(self, input, label):
        # Spatially replicate the label vector and concatenate it to the features
        label = label.unsqueeze(2).unsqueeze(3)
        label = label.expand(label.size(0), label.size(1), input.size(2), input.size(3))
        x = torch.cat([input, label], 1)
        h = self.main(x)
        h_1 = self.conv1(h)
        h_2 = self.conv2(h)
        out_image = self.m(h_1)
        out_landmark = self.m(h_2)
        return out_image, out_landmark


class Discriminator(nn.Module):
    """Discriminator"""
    def __init__(self, image_size=128, first_dim=64, repeat_num=6):
        super(Discriminator, self).__init__()

        net = []
        net.append(nn.Conv2d(6, first_dim, kernel_size=4, stride=2, padding=1))
        net.append(nn.LeakyReLU(0.01, inplace=True))

        channel_dim = first_dim
        for i in range(1, repeat_num):
            net.append(nn.Conv2d(channel_dim, channel_dim * 2, kernel_size=4, stride=2, padding=1))
            net.append(nn.LeakyReLU(0.01, inplace=True))
            channel_dim = channel_dim * 2

        self.main = nn.Sequential(*net)
        self.conv1 = nn.Conv2d(channel_dim, 1, kernel_size=3, stride=1, padding=1, bias=False)
        # there are seven expression classes
        self.fc = nn.Linear(channel_dim * 2 * 2, 7)

    def forward(self, x):
        h = self.main(x)
        out_real = self.conv1(h)
        h_cls = h.view(-1, num_flat_features(h))
        return out_real.squeeze(), self.fc(h_cls)
--------------------------------------------------------------------------------
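As a sanity check on the architecture above, here is a minimal shape walk-through with the default settings (128×128 inputs, `conv_dim=64`, four down-sampling steps). The random tensors, batch size, and chosen target class are illustrative only; the snippet uses the same PyTorch-0.3-era `Variable` API as the rest of the repository:
```
import torch
from torch.autograd import Variable
from attribute_transfer_model import Encoder, Decoder, Discriminator

enc, dec, disc = Encoder(), Decoder(), Discriminator()

# 6-channel input: RGB face concatenated with its 3-channel landmark heatmap
x = torch.randn(2, 6, 128, 128)
y = torch.zeros(2, 7)
y[:, 3] = 1                                   # one-hot target expression (index 3 is arbitrary)
x, y = Variable(x), Variable(y)

feat = enc(x)                                 # (2, 1024, 8, 8): 64 * 2**4 channels, 128 / 2**4 spatial
img, lmk = dec(feat, y)                       # two (2, 3, 128, 128) outputs in [-1, 1] (Tanh)
src, cls = disc(torch.cat([img, lmk], 1))     # realness map (2, 2, 2) and 7-way logits (2, 7)
```
--------------------------------------------------------------------------------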
2 114 | 115 | self.main = nn.Sequential(*net) 116 | self.conv1 = nn.Conv2d(channel_dim, 1, kernel_size=3, stride=1, padding=1, bias=False) 117 | # there are seven expression classes 118 | self.fc = nn.Linear(channel_dim * 2 * 2, 7) 119 | 120 | def forward(self, x): 121 | h = self.main(x) 122 | out_real = self.conv1(h) 123 | h_cls = h.view(-1, num_flat_features(h)) 124 | return out_real.squeeze(), self.fc(h_cls) 125 | 126 | -------------------------------------------------------------------------------- /data_loader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | from glob import glob 4 | from torch.utils.data import Dataset 5 | from torch.utils.data import DataLoader 6 | from torchvision import transforms 7 | from torchvision.datasets import ImageFolder 8 | from PIL import Image 9 | 10 | 11 | 12 | 13 | def load_data_list(data_dir): 14 | path = os.path.join(data_dir, '', '*') 15 | file_list = glob(path) 16 | return file_list 17 | 18 | class ConcatDataset(torch.utils.data.Dataset): 19 | def __init__(self, *datasets): 20 | self.datasets = datasets 21 | 22 | def __getitem__(self, i): 23 | return tuple(d[i] for d in self.datasets) 24 | 25 | def __len__(self): 26 | return min(len(d) for d in self.datasets) 27 | 28 | 29 | def return_loader(crop_size, image_size, batch_size, mode='train'): 30 | """Return data loader.""" 31 | 32 | if mode == 'train': 33 | transform = transforms.Compose([ 34 | transforms.CenterCrop(crop_size), 35 | transforms.Scale(image_size), 36 | transforms.RandomHorizontalFlip(), 37 | transforms.ToTensor(), 38 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] 39 | ) 40 | else: 41 | transform = transforms.Compose([ 42 | transforms.CenterCrop(crop_size), 43 | transforms.Scale(image_size), 44 | transforms.ToTensor(), 45 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) 46 | 47 | 48 | shuffle = False 49 | 50 | if mode == 'train': 51 | shuffle = True 52 | 53 | #Path to image folders of expression classes for both image and landmark heatmap 54 | traindir_img='/data/train_face_data/image/' 55 | traindir_heatmap='/data/train_face_data/landmark/' 56 | 57 | 58 | data_loader =DataLoader( 59 | dataset=ConcatDataset( 60 | ImageFolder(traindir_img,transform), 61 | ImageFolder(traindir_heatmap,transform) 62 | ), 63 | batch_size=batch_size, shuffle=shuffle) 64 | return data_loader 65 | -------------------------------------------------------------------------------- /expression_synthesis.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from torch.backends import cudnn 4 | from solver import Solver 5 | from data_loader import return_loader 6 | 7 | 8 | 9 | if __name__ == '__main__': 10 | parser = argparse.ArgumentParser() 11 | 12 | # Model parameters 13 | parser.add_argument('--y_dim', type=int, default=7) 14 | parser.add_argument('--face_crop_size', type=int, default=256) 15 | parser.add_argument('--im_size', type=int, default=128) 16 | parser.add_argument('--g_first_dim', type=int, default=64) 17 | parser.add_argument('--d_first_dim', type=int, default=64) 18 | parser.add_argument('--enc_repeat_num', type=int, default=6) 19 | parser.add_argument('--num_layers', type=int, default=3) 20 | parser.add_argument('--d_repeat_num', type=int, default=6) 21 | parser.add_argument('--lambda_cls', type=float, default=1) 22 | parser.add_argument('--lambda_id', type=float, default=10) 23 | parser.add_argument('--lambda_bi', type=float, default=10) 24 
/expression_synthesis.py:
--------------------------------------------------------------------------------
import os
import argparse
from torch.backends import cudnn
from solver import Solver
from data_loader import return_loader


if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    # Model parameters
    parser.add_argument('--y_dim', type=int, default=7)
    parser.add_argument('--face_crop_size', type=int, default=256)
    parser.add_argument('--im_size', type=int, default=128)
    parser.add_argument('--g_first_dim', type=int, default=64)
    parser.add_argument('--d_first_dim', type=int, default=64)
    parser.add_argument('--enc_repeat_num', type=int, default=6)
    parser.add_argument('--num_layers', type=int, default=3)
    parser.add_argument('--d_repeat_num', type=int, default=6)
    parser.add_argument('--lambda_cls', type=float, default=1)
    parser.add_argument('--lambda_id', type=float, default=10)
    parser.add_argument('--lambda_bi', type=float, default=10)
    parser.add_argument('--lambda_gp', type=float, default=10)
    parser.add_argument('--d_train_repeat', type=int, default=5)
    parser.add_argument('--enc_lr', type=float, default=0.0001)
    parser.add_argument('--dec_lr', type=float, default=0.0001)
    parser.add_argument('--d_lr', type=float, default=0.0001)

    # Training settings
    parser.add_argument('--batch_size', type=int, default=8)
    parser.add_argument('--beta1', type=float, default=0.5)
    parser.add_argument('--beta2', type=float, default=0.999)
    parser.add_argument('--num_epochs', type=int, default=300)
    parser.add_argument('--num_epochs_decay', type=int, default=100)
    parser.add_argument('--num_iters', type=int, default=160000)
    parser.add_argument('--num_iters_decay', type=int, default=60000)
    parser.add_argument('--trained_model', type=str, default='')

    # Test settings
    parser.add_argument('--test_model', type=str, default='')

    # Set mode (train or test)
    parser.add_argument('--mode', type=str, default='train', choices=['train', 'test'])

    # Paths to save models and logs
    parser.add_argument('--log_path', type=str, default='/main_folder/logs')
    parser.add_argument('--model_path', type=str, default='/main_folder/models')
    parser.add_argument('--sample_path', type=str, default='/main_folder/samples')
    parser.add_argument('--test_path', type=str, default='/main_folder/results')

    # Step sizes
    parser.add_argument('--log_step', type=int, default=10)
    parser.add_argument('--sample_step', type=int, default=150)
    parser.add_argument('--model_save_step', type=int, default=400)

    config = parser.parse_args()
    print(config)
    cudnn.benchmark = True

    # Create directories if they do not exist
    if not os.path.exists(config.log_path):
        os.makedirs(config.log_path)
    if not os.path.exists(config.model_path):
        os.makedirs(config.model_path)
    if not os.path.exists(config.sample_path):
        os.makedirs(config.sample_path)
    if not os.path.exists(config.test_path):
        os.makedirs(config.test_path)

    face_data_loader = return_loader(config.face_crop_size,
                                     config.im_size, config.batch_size, config.mode)
    # Solver
    solver = Solver(face_data_loader, config)

    if config.mode == 'train':
        solver.train()
    elif config.mode == 'test':
        solver.test()
--------------------------------------------------------------------------------
/sequence_landmark.py:
--------------------------------------------------------------------------------
from imutils import face_utils
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
import imutils
import time
import dlib
import cv2
import os


def list_files(path):
    # Returns a list of names (with extension, without full path) of all
    # files in the folder `path`
    files = []
    for name in os.listdir(path):
        if os.path.isfile(os.path.join(path, name)):
            files.append(name)
    return files


# 2D Gaussian function (Python 2 tuple-parameter syntax)
def twoD_Gaussian((x, y), xo, yo, sigma_x, sigma_y):
    # Axis-aligned Gaussian: exp(-((x-xo)^2/(2*sigma_x^2) + (y-yo)^2/(2*sigma_y^2)))
    a = 1. / (2 * sigma_x ** 2)
    c = 1. / (2 * sigma_y ** 2)
    g = np.exp(-(a * ((x - xo) ** 2) + c * ((y - yo) ** 2)))
    return g.ravel()


# Store the shape predictor's path
predictor_path = os.path.abspath("/main_folder/shape_predictor_68_face_landmarks.dat")
output_path = '/main_folder/landmark_results'
# Display width
display_width = 1000
scale = 1
downscaled_width = display_width / 2

# Load dlib's face detector
print("[INFO] loading facial landmark predictor...")
detector = dlib.get_frontal_face_detector()

# Load the shape predictor
predictor = dlib.shape_predictor(predictor_path)

# Read the image sequence
sequence_path = '/data/face_image_sequence'
list_imgs = list_files(sequence_path)

for i, im in enumerate(list_imgs):
    infile = os.path.join(sequence_path, im)
    RGB_frame = cv2.imread(infile)
    frame = cv2.imread(infile, cv2.IMREAD_GRAYSCALE)
    height, width = frame.shape[:2]
    s_height, s_width = height // scale, width // scale
    imgDim = s_height
    img = cv2.resize(frame, (s_width, s_height))

    # Detect the faces
    rects = detector(img, 0)

    # For each detected face
    for rect in rects:
        # Apply the facial landmark predictor to the region
        shape = predictor(img, rect)
        shape = face_utils.shape_to_np(shape)
        npLandmarks = np.float32(shape)
        # Drop the two inner mouth-corner points, then the 17 jawline points
        shape = np.delete(shape, [60, 64], axis=0)
        shape = shape[17:]
        N = shape.shape[0]
        # Convert dlib coordinates to OpenCV coordinates
        (x, y, w, h) = face_utils.rect_to_bb(rect)
        faceOrig = RGB_frame[y:y + h, x:x + w]
        arr = np.zeros((h, w), np.float)
        arr_2 = np.zeros((height, width), np.float)
        yy, xx = np.mgrid[y:h + y, x:w + x]
        cv2.rectangle(frame, (x, y), (x + w, y + h), (255, 0, 0), 2)

        # Point size is proportional to the face size
        radius = (dlib.rectangle.height(rect) * scale / 90)

        # For every landmark coordinate
        for (xs, ys) in shape:
            # Draw the point onto the frame and accumulate its Gaussian bump
            cv2.circle(frame, (xs, ys), radius, (0, 0, 255), -1)
            Gauss = twoD_Gaussian((xx, yy), xs, ys, 10, 10)
            map = Gauss.reshape(xx.shape[0], yy.shape[1])
            arr = arr + map / N

        # Normalize the accumulated heatmap to [0, 255] and paste it back
        # into a full-frame canvas
        minval = arr.min()
        maxval = arr.max()
        arr -= minval
        arr *= (255.0 / (maxval - minval))
        arr_2[y:h + y, x:w + x] = arr
        map_result = Image.fromarray((arr_2).astype(np.uint8))
        out_path = output_path + str(infile)[-13:-4] + '.png'
        map_result.save(out_path)
--------------------------------------------------------------------------------
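A self-contained miniature of the heatmap construction above, on a hypothetical 8×8 patch with two made-up landmark positions (the sigma and coordinates are illustrative; the real script uses `sigma_x = sigma_y = 10` over the detected face box):
```
import numpy as np

# Two hypothetical landmarks on a toy 8x8 patch (coordinates are made up)
yy, xx = np.mgrid[0:8, 0:8]
landmarks = [(2, 3), (5, 5)]          # (xs, ys) pairs
sigma = 2.0
heatmap = np.zeros((8, 8))

for (xs, ys) in landmarks:
    # Same axis-aligned Gaussian as twoD_Gaussian, averaged over the N points
    g = np.exp(-(((xx - xs) ** 2 + (yy - ys) ** 2) / (2 * sigma ** 2)))
    heatmap += g / len(landmarks)

# Rescale to [0, 255], exactly as the script does before saving
heatmap = (heatmap - heatmap.min()) * (255.0 / (heatmap.max() - heatmap.min()))
print(heatmap.round().astype(np.uint8))
```
--------------------------------------------------------------------------------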
/shape_predictor_68_face_landmarks.dat:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:fbdc2cb80eb9aa7a758672cbfdda32ba6300efe9b6e6c7a299ff7e736b11b92f
size 99693937
--------------------------------------------------------------------------------
/solver.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import grad
from torch.autograd import Variable
from torchvision.utils import save_image
from torchvision import transforms
import numpy as np
import os
import time
import datetime
from attribute_transfer_model import Discriminator
from attribute_transfer_model import Encoder
from attribute_transfer_model import Decoder
from PIL import Image


class Solver(object):

    def __init__(self, face_data_loader, config):
        # Data loader
        self.face_data_loader = face_data_loader

        # Model parameters
        self.y_dim = config.y_dim
        self.num_layers = config.num_layers
        self.im_size = config.im_size
        self.g_first_dim = config.g_first_dim
        self.d_first_dim = config.d_first_dim
        self.enc_repeat_num = config.enc_repeat_num
        self.d_repeat_num = config.d_repeat_num
        self.d_train_repeat = config.d_train_repeat

        # Hyper-parameters
        self.lambda_cls = config.lambda_cls
        self.lambda_id = config.lambda_id
        self.lambda_bi = config.lambda_bi
        self.lambda_gp = config.lambda_gp
        self.enc_lr = config.enc_lr
        self.dec_lr = config.dec_lr
        self.d_lr = config.d_lr
        self.beta1 = config.beta1
        self.beta2 = config.beta2

        # Training settings
        self.num_epochs = config.num_epochs
        self.num_epochs_decay = config.num_epochs_decay
        self.num_iters = config.num_iters
        self.num_iters_decay = config.num_iters_decay
        self.batch_size = config.batch_size
        self.trained_model = config.trained_model

        # Test settings
        self.test_model = config.test_model

        # Paths
        self.log_path = config.log_path
        self.sample_path = config.sample_path
        self.model_path = config.model_path
        self.test_path = config.test_path

        # Step sizes
        self.log_step = config.log_step
        self.sample_step = config.sample_step
        self.model_save_step = config.model_save_step

        # Build the models and set up tensorboard
        self.build_model()
        self.use_tensorboard()

        # Start with a trained model
        if self.trained_model:
            self.load_trained_model()

    def build_model(self):
        # Define the encoder-decoder (generator) and the discriminator
        self.Enc = Encoder(self.g_first_dim, self.enc_repeat_num)
        self.Dec = Decoder(self.g_first_dim)
        self.D = Discriminator(self.im_size, self.d_first_dim, self.d_repeat_num)

        # Optimizers
        self.enc_optimizer = torch.optim.Adam(self.Enc.parameters(), self.enc_lr, [self.beta1, self.beta2])
        self.dec_optimizer = torch.optim.Adam(self.Dec.parameters(), self.dec_lr, [self.beta1, self.beta2])
        self.d_optimizer = torch.optim.Adam(self.D.parameters(), self.d_lr, [self.beta1, self.beta2])

        if torch.cuda.is_available():
            self.Enc.cuda()
            self.Dec.cuda()
            self.D.cuda()

    def load_trained_model(self):
        self.Enc.load_state_dict(torch.load(os.path.join(
            self.model_path, '{}_Enc.pth'.format(self.trained_model))))
        self.Dec.load_state_dict(torch.load(os.path.join(
            self.model_path, '{}_Dec.pth'.format(self.trained_model))))
        self.D.load_state_dict(torch.load(os.path.join(
            self.model_path, '{}_D.pth'.format(self.trained_model))))
        print('loaded models (step: {})..!'.format(self.trained_model))

    def use_tensorboard(self):
        from tensorboard_logger import Logger
        self.logger = Logger(self.log_path)

    def update_lr(self, enc_lr, dec_lr, d_lr):
        for param_group in self.enc_optimizer.param_groups:
            param_group['lr'] = enc_lr
        for param_group in self.dec_optimizer.param_groups:
            param_group['lr'] = dec_lr
        for param_group in self.d_optimizer.param_groups:
            param_group['lr'] = d_lr

    def reset(self):
        self.enc_optimizer.zero_grad()
        self.dec_optimizer.zero_grad()
        self.d_optimizer.zero_grad()

    def to_var(self, x, volatile=False):
        if torch.cuda.is_available():
            x = x.cuda()
        return Variable(x, volatile=volatile)

    def calculate_accuracy(self, x, y):
        _, predicted = torch.max(x, dim=1)
        correct = (predicted == y).float()
        accuracy = torch.mean(correct) * 100.0
        return accuracy

    def denorm(self, x):
        # Map network outputs from [-1, 1] back to [0, 1] for saving
        out = (x + 1) / 2
        return out.clamp_(0, 1)

    def one_hot(self, labels, dim):
        """Convert label indices to one-hot vectors"""
        batch_size = labels.size(0)
        out = torch.zeros(batch_size, dim)
        out[np.arange(batch_size), labels.long()] = 1
        return out

    def train(self):
        """Train the attribute-guided face image synthesis model"""
        self.data_loader = self.face_data_loader
        # The number of iterations per epoch
        iters_per_epoch = len(self.data_loader)

        # Collect a fixed batch of inputs for visualizing progress
        sample_x = []
        sample_l = []
        real_y = []
        for i, (images, landmark) in enumerate(self.data_loader):
            labels = images[1]
            sample_x.append(images[0])
            sample_l.append(landmark[0])
            real_y.append(labels)
            if i == 2:
                break

        # Sample inputs and desired domain labels for testing
        sample_x = torch.cat(sample_x, dim=0)
        sample_x = self.to_var(sample_x, volatile=True)
        sample_l = torch.cat(sample_l, dim=0)
        sample_l = self.to_var(sample_l, volatile=True)
        real_y = torch.cat(real_y, dim=0)

        sample_y_list = []
        for i in range(self.y_dim):
            sample_y = self.one_hot(torch.ones(sample_x.size(0)) * i, self.y_dim)
            sample_y_list.append(self.to_var(sample_y, volatile=True))

        # Learning rates for decaying
        d_lr = self.d_lr
        enc_lr = self.enc_lr
        dec_lr = self.dec_lr

        # Start with a trained model
        if self.trained_model:
            start = int(self.trained_model.split('_')[0])
        else:
            start = 0

        # Start training
        start_time = time.time()
        for e in range(start, self.num_epochs):
            for i, (real_image, real_landmark) in enumerate(self.data_loader):
                # real_x: real image; real_l: conditional side image (landmark heatmap)
                real_x = real_image[0]
                real_label = real_image[1]
                real_l = real_landmark[0]

                # Sample fake labels randomly
                rand_idx = torch.randperm(real_label.size(0))
                fake_label = real_label[rand_idx]

                real_y = self.one_hot(real_label, self.y_dim)
                fake_y = self.one_hot(fake_label, self.y_dim)

                # Convert tensors to variables
                real_x = self.to_var(real_x)
                real_l = self.to_var(real_l)
                real_y = self.to_var(real_y)
                fake_y = self.to_var(fake_y)
                real_label = self.to_var(real_label)
                fake_label = self.to_var(fake_label)

                # ================== Train Discriminator ================== #
                # Input images (original image + side image) are concatenated
                src_output, cls_output = self.D(torch.cat([real_x, real_l], 1))
                d_loss_real = - torch.mean(src_output)
                d_loss_cls = F.cross_entropy(cls_output, real_label)

                # Compute expression recognition accuracy on the real images
                if (i+1) % self.log_step == 0:
                    accuracies = self.calculate_accuracy(cls_output, real_label)
                    log = ["{:.2f}".format(acc) for acc in accuracies.data.cpu().numpy()]
                    print('Recognition Acc: ')
                    print(log)

                # Generate outputs and compute the loss on the fake generated images
                enc_feat = self.Enc(torch.cat([real_x, real_l], 1))
                fake_x, fake_l = self.Dec(enc_feat, fake_y)
                # Detach so the generator is not updated through the discriminator loss
                fake_x = Variable(fake_x.data)
                fake_l = Variable(fake_l.data)

                src_output, cls_output = self.D(torch.cat([fake_x, fake_l], 1))
                d_loss_fake = torch.mean(src_output)

                # Discriminator losses
                d_loss = self.lambda_cls * d_loss_cls + d_loss_real + d_loss_fake
                self.reset()
                d_loss.backward()
                self.d_optimizer.step()

                # Compute the gradient penalty loss (WGAN-GP) on points
                # interpolated between real and fake pairs
                real = torch.cat([real_x, real_l], 1)
                fake = torch.cat([fake_x, fake_l], 1)
                alpha = torch.rand(real_x.size(0), 1, 1, 1).cuda().expand_as(real)
                interpolated = Variable(alpha * real.data + (1 - alpha) * fake.data, requires_grad=True)
                output, cls_output = self.D(interpolated)

                grad = torch.autograd.grad(outputs=output,
                                           inputs=interpolated,
                                           grad_outputs=torch.ones(output.size()).cuda(),
                                           retain_graph=True,
                                           create_graph=True,
                                           only_inputs=True)[0]

                grad = grad.view(grad.size(0), -1)
                grad_l2norm = torch.sqrt(torch.sum(grad ** 2, dim=1))
                d_loss_gp = torch.mean((grad_l2norm - 1) ** 2)

                # Gradient penalty loss
                d_loss = self.lambda_gp * d_loss_gp
                self.reset()
                d_loss.backward()
                self.d_optimizer.step()

                # Logging
                loss = {}
                loss['D/loss_real'] = d_loss_real.data[0]
                loss['D/loss_fake'] = d_loss_fake.data[0]
                loss['D/loss_cls'] = d_loss_cls.data[0]
                loss['D/loss_gp'] = d_loss_gp.data[0]

                # ================== Train Encoder-Decoder networks ================== #
                if (i+1) % self.d_train_repeat == 0:

                    # Original-to-target and target-to-original domain translation
                    enc_feat = self.Enc(torch.cat([real_x, real_l], 1))
                    fake_x, fake_l = self.Dec(enc_feat, fake_y)
                    src_output, cls_output = self.D(torch.cat([fake_x, fake_l], 1))
                    g_loss_fake = - torch.mean(src_output)

                    # rec_feat = self.Enc(fake_x)
                    rec_feat = self.Enc(torch.cat([fake_x, fake_l], 1))
                    rec_x, rec_l = self.Dec(rec_feat, real_y)

                    # Bidirectional loss on the images
                    g_loss_rec_x = torch.mean(torch.abs(real_x - rec_x))
                    g_loss_rec_l = torch.mean(torch.abs(real_l - rec_l))

                    # Bidirectional loss on the latent features
                    g_loss_feature = torch.mean(torch.abs(enc_feat - rec_feat))

                    # Identity loss on the images
                    g_loss_identity_x = torch.mean(torch.abs(real_x - fake_x))
                    g_loss_identity_l = torch.mean(torch.abs(real_l - fake_l))

                    # Attribute classification loss on the fake generated images
                    g_loss_cls = F.cross_entropy(cls_output, fake_label)

                    # Backward + optimize the generator (encoder-decoder) losses;
                    # the decoder is updated twice for each encoder update
                    g_loss = g_loss_fake \
                        + self.lambda_bi * (g_loss_rec_x + g_loss_rec_l + g_loss_feature) \
                        + self.lambda_id * (g_loss_identity_x + g_loss_identity_l) \
                        + self.lambda_cls * g_loss_cls
                    self.reset()
                    g_loss.backward()
                    self.enc_optimizer.step()
                    self.dec_optimizer.step()
                    self.dec_optimizer.step()

                    # Logging generator losses
                    loss['G/loss_feature'] = g_loss_feature.data[0]
                    loss['G/loss_identity_x'] = g_loss_identity_x.data[0]
                    loss['G/loss_identity_l'] = g_loss_identity_l.data[0]
                    loss['G/loss_rec_x'] = g_loss_rec_x.data[0]
                    loss['G/loss_rec_l'] = g_loss_rec_l.data[0]
                    loss['G/loss_fake'] = g_loss_fake.data[0]
                    loss['G/loss_cls'] = g_loss_cls.data[0]

                # Print out the log
                if (i+1) % self.log_step == 0:
                    elapsed = time.time() - start_time
                    elapsed = str(datetime.timedelta(seconds=elapsed))

                    log = "Elapsed [{}], Epoch [{}/{}], Iter [{}/{}]".format(
                        elapsed, e+1, self.num_epochs, i+1, iters_per_epoch)

                    for tag, value in loss.items():
                        log += ", {}: {:.4f}".format(tag, value)
                    print(log)

                    for tag, value in loss.items():
                        self.logger.scalar_summary(tag, value, e * iters_per_epoch + i + 1)

                # Synthesize images
                if (i+1) % self.sample_step == 0:
                    fake_image_list = [sample_x]
                    for sample_y in sample_y_list:
                        enc_feat = self.Enc(torch.cat([sample_x, sample_l], 1))
                        sample_result, sample_landmark = self.Dec(enc_feat, sample_y)
                        fake_image_list.append(sample_result)
                    fake_images = torch.cat(fake_image_list, dim=3)
                    save_image(self.denorm(fake_images.data),
                               os.path.join(self.sample_path, '{}_{}_fake.png'.format(e+1, i+1)), nrow=1, padding=0)
                    print('Generated images and saved into {}..!'.format(self.sample_path))

                # Save checkpoints
                if (i+1) % self.model_save_step == 0:
                    torch.save(self.Enc.state_dict(),
                               os.path.join(self.model_path, '{}_{}_Enc.pth'.format(e+1, i+1)))
                    torch.save(self.Dec.state_dict(),
                               os.path.join(self.model_path, '{}_{}_Dec.pth'.format(e+1, i+1)))
                    torch.save(self.D.state_dict(),
                               os.path.join(self.model_path, '{}_{}_D.pth'.format(e+1, i+1)))

            # Decay the learning rates
            if (e+1) > (self.num_epochs - self.num_epochs_decay):
                d_lr -= (self.d_lr / float(self.num_epochs_decay))
                enc_lr -= (self.enc_lr / float(self.num_epochs_decay))
                dec_lr -= (self.dec_lr / float(self.num_epochs_decay))
                self.update_lr(enc_lr, dec_lr, d_lr)
                print('Decay learning rate to enc_lr: {}, dec_lr: {}, d_lr: {}.'.format(enc_lr, dec_lr, d_lr))

    def test(self):
        """Generate face images with the target attributes (desired expressions)"""
        # Load the trained models
        Enc_path = os.path.join(self.model_path, '{}_Enc.pth'.format(self.test_model))
        Dec_path = os.path.join(self.model_path, '{}_Dec.pth'.format(self.test_model))
        self.Enc.load_state_dict(torch.load(Enc_path))
        self.Dec.load_state_dict(torch.load(Dec_path))
        self.Enc.eval()
        self.Dec.eval()

        data_loader = self.face_data_loader

        for i, (real_image, real_landmark) in enumerate(data_loader):
            org_c = real_image[1]
            real_x = real_image[0]
            real_l = real_landmark[0]
            real_x = self.to_var(real_x, volatile=True)
            real_l = self.to_var(real_l, volatile=True)

            target_y_list = []
            for j in range(self.y_dim):
                target_y = self.one_hot(torch.ones(real_x.size(0)) * j, self.y_dim)
                target_y_list.append(self.to_var(target_y, volatile=True))

            # Target image generation
            fake_image_list = [real_x]
            for target_y in target_y_list:
                enc_feat = self.Enc(torch.cat([real_x, real_l], 1))
                sample_result, sample_landmark = self.Dec(enc_feat, target_y)
                fake_image_list.append(sample_result)
            fake_images = torch.cat(fake_image_list, dim=3)
            save_path = os.path.join(self.test_path, '{}_fake.png'.format(i+1))
            save_image(self.denorm(fake_images.data), save_path, nrow=1, padding=0)
            print('Generated images and saved into "{}"..!'.format(save_path))
--------------------------------------------------------------------------------
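For reference, the objectives implemented in `Solver.train()` above can be summarized as follows. This is a restatement of the code, not an addition: `D_src` is the realness output, `D_cls` the expression logits, `CE` cross-entropy, `f` the encoder features, hats denote generated quantities, and the `λ`s are the `lambda_*` command-line flags. The code optimizes the gradient-penalty term in a second backward pass, and the L1 terms are means over elements:
```
\mathcal{L}_D = -\mathbb{E}[D_{src}(x, l)] + \mathbb{E}[D_{src}(\hat{x}, \hat{l})]
              + \lambda_{cls}\, \mathrm{CE}(D_{cls}(x, l),\, y)
              + \lambda_{gp}\, \mathbb{E}\big[(\lVert \nabla_{\tilde{x}} D_{src}(\tilde{x}) \rVert_2 - 1)^2\big]

\mathcal{L}_{Enc,Dec} = -\mathbb{E}[D_{src}(\hat{x}, \hat{l})]
              + \lambda_{bi}\big(\lVert x - x_{rec} \rVert_1 + \lVert l - l_{rec} \rVert_1 + \lVert f - f_{rec} \rVert_1\big)
              + \lambda_{id}\big(\lVert x - \hat{x} \rVert_1 + \lVert l - \hat{l} \rVert_1\big)
              + \lambda_{cls}\, \mathrm{CE}(D_{cls}(\hat{x}, \hat{l}),\, \hat{y})
```
--------------------------------------------------------------------------------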
/tensorboard_logger.py:
--------------------------------------------------------------------------------
# This code is from https://gist.github.com/gyglim/1f8dfb1b5c82627ae3efcfbbadb9f514
import tensorflow as tf
import numpy as np
import scipy.misc

try:
    from StringIO import StringIO  # Python 2.7
except ImportError:
    from io import BytesIO  # Python 3.5+


class Logger(object):
    def __init__(self, log_dir):
        """Create a summary writer logging to log_dir."""
        self.writer = tf.summary.FileWriter(log_dir)

    def scalar_summary(self, tag, value, step):
        """Log a scalar variable."""
        summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)])
        self.writer.add_summary(summary, step)

    def image_summary(self, tag, images, step):
        """Log a list of images."""

        img_summaries = []
        for i, img in enumerate(images):
            # Write the image to a string buffer
            try:
                s = StringIO()
            except NameError:  # StringIO is undefined on Python 3
                s = BytesIO()
            scipy.misc.toimage(img).save(s, format="png")

            # Create an Image object
            img_sum = tf.Summary.Image(encoded_image_string=s.getvalue(),
                                       height=img.shape[0],
                                       width=img.shape[1])
            # Create a Summary value
            img_summaries.append(tf.Summary.Value(tag='%s/%d' % (tag, i), image=img_sum))

        # Create and write the Summary
        summary = tf.Summary(value=img_summaries)
        self.writer.add_summary(summary, step)

    def histo_summary(self, tag, values, step, bins=1000):
        """Log a histogram of the tensor of values."""

        # Create a histogram using numpy
        counts, bin_edges = np.histogram(values, bins=bins)

        # Fill the fields of the histogram proto
        hist = tf.HistogramProto()
        hist.min = float(np.min(values))
        hist.max = float(np.max(values))
        hist.num = int(np.prod(values.shape))
        hist.sum = float(np.sum(values))
        hist.sum_squares = float(np.sum(values ** 2))

        # Drop the start of the first bin
        bin_edges = bin_edges[1:]

        # Add bin edges and counts
        for edge in bin_edges:
            hist.bucket_limit.append(edge)
        for c in counts:
            hist.bucket.append(c)

        # Create and write the Summary
        summary = tf.Summary(value=[tf.Summary.Value(tag=tag, histo=hist)])
        self.writer.add_summary(summary, step)
        self.writer.flush()
--------------------------------------------------------------------------------
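Finally, a quick sketch of how `Logger` is consumed; this mirrors `Solver.use_tensorboard()` and the scalar logging in the training loop (the tag and values below are made up):
```
from tensorboard_logger import Logger

logger = Logger('/main_folder/logs')     # the trainer's default log_path
for step in range(1, 11):
    logger.scalar_summary('D/loss_real', 1.0 / step, step)   # made-up values
logger.writer.flush()
# Inspect with: tensorboard --logdir=/main_folder/logs
```
--------------------------------------------------------------------------------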