├── Dataset.py ├── Evaluation └── ssim.py ├── IMLib ├── __init__.py └── utils.py ├── PyLib ├── __init__.py ├── const.py └── utils.py ├── README.md ├── SwapAutoEncoderAdaIN.py ├── config ├── __init__.py ├── options.py ├── test_options.py └── train_options.py ├── img ├── model.png ├── test.jpg └── train.jpg ├── requirements.txt ├── scripts ├── test_log10_10_1.sh └── train_log10_10_1.sh ├── test.py ├── tfLib ├── __init__.py ├── advloss.py ├── flowfield.py ├── gp.py ├── loss.py └── ops.py └── train.py /Dataset.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from IMLib.utils import * 3 | 4 | ATT_ID = {'5_o_Clock_Shadow': 0, 'Arched_Eyebrows': 1, 'Attractive': 2, 5 | 'Bags_Under_Eyes': 3, 'Bald': 4, 'Bangs': 5, 'Big_Lips': 6, 6 | 'Big_Nose': 7, 'Black_Hair': 8, 'Blond_Hair': 9, 'Blurry': 10, 7 | 'Brown_Hair': 11, 'Bushy_Eyebrows': 12, 'Chubby': 13, 8 | 'Double_Chin': 14, 'Eyeglasses': 15, 'Goatee': 16, 9 | 'Gray_Hair': 17, 'Heavy_Makeup': 18, 'High_Cheekbones': 19, 10 | 'Male': 20, 'Mouth_Slightly_Open': 21, 'Mustache': 22, 11 | 'Narrow_Eyes': 23, 'No_Beard': 24, 'Oval_Face': 25, 12 | 'Pale_Skin': 26, 'Pointy_Nose': 27, 'Receding_Hairline': 28, 13 | 'Rosy_Cheeks': 29, 'Sideburns': 30, 'Smiling': 31, 14 | 'Straight_Hair': 32, 'Wavy_Hair': 33, 'Wearing_Earrings': 34, 15 | 'Wearing_Hat': 35, 'Wearing_Lipstick': 36, 16 | 'Wearing_Necklace': 37, 'Wearing_Necktie': 38, 'Young': 39} 17 | 18 | class CelebA(object): 19 | 20 | def __init__(self, config): 21 | super(CelebA, self).__init__() 22 | 23 | self.data_dir = config.data_dir 24 | self.label_dir = config.label_dir 25 | self.dataset_name = 'CelebA' 26 | self.height, self.width= config.img_size, config.img_size 27 | self.channel = config.output_nc 28 | self.capacity = config.capacity 29 | self.batch_size = config.batch_size 30 | self.num_threads = config.num_threads 31 | self.chosen_att_names = config.chosen_att_names 32 | 33 | self.img_names = np.genfromtxt(self.label_dir, dtype=str, usecols=0) 34 | self.img_paths = np.array([os.path.join(self.data_dir, img_name) for img_name in self.img_names[2:]]) 35 | self.labels = self.read_txt(self.label_dir) #np.genfromtxt(self.label_dir, dtype=str, usecols=range(0, 41), delimiter='[/\s:]+') 36 | 37 | self.labels = self.labels[:, np.array([ATT_ID[att_name] for att_name in self.chosen_att_names])] 38 | 39 | self.labels = np.stack([[self.labeltoCat(item) for item in self.labels]], axis=-1) 40 | 41 | assert len(self.labels) == len(self.img_paths) 42 | 43 | self.train_images_list = self.img_paths[0:29000, ...] 44 | self.test_images_list = self.img_paths[29000:-1, ...] 45 | self.train_label = self.labels[0:29000, ...] 46 | self.test_label = self.labels[29000:-1, ...] 
47 | 48 | print(self.train_images_list[0], len(self.test_images_list), len(self.train_label), len(self.test_label)) 49 | 50 | def read_images(self, input_queue): 51 | 52 | content = tf.read_file(input_queue) 53 | img = tf.image.decode_jpeg(content, channels=self.channel) 54 | img = tf.cast(img, tf.float32) 55 | img = tf.image.random_flip_left_right(img) 56 | img = tf.image.central_crop(img, central_fraction=0.9) 57 | img = tf.image.resize_images(img, (self.height, self.width)) 58 | 59 | return img / 127.5 - 1.0 60 | 61 | def labeltoCat(self, l): # encode a binary attribute vector as a base-2 categorical index 62 | cat = 0 63 | for i, item in enumerate(l): 64 | cat += item * pow(2, i) 65 | return cat 66 | 67 | def read_txt(self, txt_path): 68 | 69 | p = open(txt_path, 'r') 70 | lines = p.readlines() 71 | labels = [] 72 | for i, line in enumerate(lines): 73 | if i == 0 or i == 1: 74 | continue 75 | line = line.replace('\n', '') 76 | items = line.split() 77 | label = [(int(item) + 1) // 2 for item in items[1:]] # map {-1, 1} to {0, 1} 78 | labels.append(label) 79 | 80 | return np.array(labels) 81 | 82 | def input(self): 83 | 84 | train_images = tf.convert_to_tensor(self.train_images_list, dtype=tf.string) 85 | 86 | train_queue = tf.train.slice_input_producer([train_images], shuffle=True) 87 | train_images_queue = self.read_images(input_queue=train_queue[0]) 88 | 89 | test_images = tf.convert_to_tensor(self.test_images_list, dtype=tf.string) 90 | test_queue = tf.train.slice_input_producer([test_images], shuffle=False) 91 | test_images_queue = self.read_images(input_queue=test_queue[0]) 92 | 93 | batch_image = tf.train.shuffle_batch([train_images_queue], 94 | batch_size=self.batch_size, 95 | capacity=self.capacity, 96 | num_threads=self.num_threads, 97 | min_after_dequeue=200) 98 | 99 | test_batch_image = tf.train.batch([test_images_queue], 100 | batch_size=self.batch_size, 101 | capacity=100, 102 | num_threads=1) 103 | 104 | return batch_image, test_batch_image 105 | -------------------------------------------------------------------------------- /Evaluation/ssim.py: -------------------------------------------------------------------------------- 1 | """Python implementation of MS-SSIM. 2 | Usage: 3 | edit the hard-coded image directories at the bottom of this file, then run: python ssim.py 4 | """ 5 | import numpy as np 6 | from scipy import signal 7 | from scipy.ndimage.filters import convolve 8 | import math 9 | import glob, os 10 | import cv2 11 | 12 | def _FSpecialGauss(size, sigma): 13 | """Function to mimic the 'fspecial' gaussian MATLAB function.""" 14 | radius = size // 2 15 | offset = 0.0 16 | start, stop = -radius, radius + 1 17 | if size % 2 == 0: 18 | offset = 0.5 19 | stop -= 1 20 | x, y = np.mgrid[offset + start:stop, offset + start:stop] 21 | assert len(x) == size 22 | g = np.exp(-((x**2 + y**2)/(2.0 * sigma**2))) 23 | return g / g.sum() 24 | 25 | def _SSIMForMultiScale(img1, img2, max_val=255, filter_size=11, 26 | filter_sigma=1.5, k1=0.01, k2=0.03): 27 | """Return the Structural Similarity Map between `img1` and `img2`. 28 | This function attempts to match the functionality of ssim_index_new.m by 29 | Zhou Wang: http://www.cns.nyu.edu/~lcv/ssim/msssim.zip 30 | Arguments: 31 | img1: Numpy array holding the first RGB image batch. 32 | img2: Numpy array holding the second RGB image batch. 33 | max_val: the dynamic range of the images (i.e., the difference between the 34 | maximum and minimum allowed values). 35 | filter_size: Size of blur kernel to use (will be reduced for small images).
36 | filter_sigma: Standard deviation for Gaussian blur kernel (will be reduced 37 | for small images). 38 | k1: Constant used to maintain stability in the SSIM calculation (0.01 in 39 | the original paper). 40 | k2: Constant used to maintain stability in the SSIM calculation (0.03 in 41 | the original paper). 42 | Returns: 43 | Pair containing the mean SSIM and contrast sensitivity between `img1` and 44 | `img2`. 45 | Raises: 46 | RuntimeError: If input images don't have the same shape or don't have four 47 | dimensions: [batch_size, height, width, depth]. 48 | """ 49 | if img1.shape != img2.shape: 50 | raise RuntimeError('Input images must have the same shape (%s vs. %s).' % 51 | (img1.shape, img2.shape)) 52 | if img1.ndim != 4: 53 | raise RuntimeError('Input images must have four dimensions, not %d' % 54 | img1.ndim) 55 | 56 | img1 = img1.astype(np.float64) 57 | img2 = img2.astype(np.float64) 58 | _, height, width, _ = img1.shape 59 | 60 | # Filter size can't be larger than height or width of images. 61 | size = min(filter_size, height, width) 62 | 63 | # Scale down sigma if a smaller filter size is used. 64 | sigma = size * filter_sigma / filter_size if filter_size else 0 65 | 66 | if filter_size: 67 | window = np.reshape(_FSpecialGauss(size, sigma), (1, size, size, 1)) 68 | mu1 = signal.fftconvolve(img1, window, mode='valid') 69 | mu2 = signal.fftconvolve(img2, window, mode='valid') 70 | sigma11 = signal.fftconvolve(img1 * img1, window, mode='valid') 71 | sigma22 = signal.fftconvolve(img2 * img2, window, mode='valid') 72 | sigma12 = signal.fftconvolve(img1 * img2, window, mode='valid') 73 | else: 74 | # Empty blur kernel so no need to convolve. 75 | mu1, mu2 = img1, img2 76 | sigma11 = img1 * img1 77 | sigma22 = img2 * img2 78 | sigma12 = img1 * img2 79 | 80 | mu11 = mu1 * mu1 81 | mu22 = mu2 * mu2 82 | mu12 = mu1 * mu2 83 | sigma11 -= mu11 84 | sigma22 -= mu22 85 | sigma12 -= mu12 86 | 87 | # Calculate intermediate values used by both ssim and cs_map. 88 | c1 = (k1 * max_val) ** 2 89 | c2 = (k2 * max_val) ** 2 90 | v1 = 2.0 * sigma12 + c2 91 | v2 = sigma11 + sigma22 + c2 92 | ssim = np.mean((((2.0 * mu12 + c1) * v1) / ((mu11 + mu22 + c1) * v2))) 93 | cs = np.mean(v1 / v2) 94 | return ssim, cs 95 | 96 | 97 | def MultiScaleSSIM(img1, img2, max_val=255, filter_size=11, filter_sigma=1.5, 98 | k1=0.01, k2=0.03, weights=None): 99 | """Return the MS-SSIM score between `img1` and `img2`. 100 | This function implements Multi-Scale Structural Similarity (MS-SSIM) Image 101 | Quality Assessment according to Zhou Wang's paper, "Multi-scale structural 102 | similarity for image quality assessment" (2003). 103 | Link: https://ece.uwaterloo.ca/~z70wang/publications/msssim.pdf 104 | Author's MATLAB implementation: 105 | http://www.cns.nyu.edu/~lcv/ssim/msssim.zip 106 | Arguments: 107 | img1: Numpy array holding the first RGB image batch. 108 | img2: Numpy array holding the second RGB image batch. 109 | max_val: the dynamic range of the images (i.e., the difference between the 110 | maximum and minimum allowed values). 111 | filter_size: Size of blur kernel to use (will be reduced for small images). 112 | filter_sigma: Standard deviation for Gaussian blur kernel (will be reduced 113 | for small images). 114 | k1: Constant used to maintain stability in the SSIM calculation (0.01 in 115 | the original paper). 116 | k2: Constant used to maintain stability in the SSIM calculation (0.03 in 117 | the original paper).
118 | weights: List of weights for each level; if none, use five levels and the 119 | weights from the original paper. 120 | Returns: 121 | MS-SSIM score between `img1` and `img2`. 122 | Raises: 123 | RuntimeError: If input images don't have the same shape or don't have four 124 | dimensions: [batch_size, height, width, depth]. 125 | """ 126 | if img1.shape != img2.shape: 127 | raise RuntimeError('Input images must have the same shape (%s vs. %s).' % 128 | (img1.shape, img2.shape)) 129 | if img1.ndim != 4: 130 | raise RuntimeError('Input images must have four dimensions, not %d' % 131 | img1.ndim) 132 | 133 | # Note: default weights don't sum to 1.0 but do match the paper / matlab code. 134 | weights = np.array(weights if weights else 135 | [0.0448, 0.2856, 0.3001, 0.2363, 0.1333]) 136 | levels = weights.size 137 | downsample_filter = np.ones((1, 2, 2, 1)) / 4.0 138 | im1, im2 = [x.astype(np.float64) for x in [img1, img2]] 139 | mssim = np.array([]) 140 | mcs = np.array([]) 141 | for _ in range(levels): 142 | ssim, cs = _SSIMForMultiScale( 143 | im1, im2, max_val=max_val, filter_size=filter_size, 144 | filter_sigma=filter_sigma, k1=k1, k2=k2) 145 | mssim = np.append(mssim, ssim) 146 | mcs = np.append(mcs, cs) 147 | filtered = [convolve(im, downsample_filter, mode='reflect') 148 | for im in [im1, im2]] 149 | im1, im2 = [x[:, ::2, ::2, :] for x in filtered] 150 | 151 | return (np.prod(mcs[0:levels-1] ** weights[0:levels-1]) * 152 | (mssim[levels-1] ** weights[levels-1])) 153 | 154 | def mssim_score(ori_list, gen_list): 155 | 156 | score = 0.0 157 | for i in range(len(ori_list)): 158 | img1 = np.expand_dims(np.asarray(ori_list[i]), axis=0) 159 | img2 = np.expand_dims(np.asarray(gen_list[i]), axis=0) 160 | 161 | result = MultiScaleSSIM(img1, img2, max_val=255) 162 | if math.isnan(result): 163 | print('Detected NaN') 164 | else: 165 | score = score + result # reuse the computed value instead of recomputing 166 | 167 | print('the MS-SSIM score is', score / len(ori_list)) 168 | return score / len(ori_list) 169 | 170 | img_path1 = '/data0/jzhang/code/GazeGAN/GazeCorrection-master/log3_28_1_19/test_sample_dir3/0' 171 | img_path2 = '/data0/jzhang/code/GazeGAN/GazeCorrection-master/log3_28_1_19/test_sample_dir3/1' 172 | 173 | img_list1 = [] 174 | img_list2 = [] 175 | 176 | path_list1 = glob.glob(os.path.join(img_path1, '*.jpg')) 177 | path_list2 = glob.glob(os.path.join(img_path2, '*.jpg')) 178 | 179 | for i, path in enumerate(path_list1): 180 | if i == 0: 181 | print(path) 182 | img = cv2.imread(path) 183 | img_list1.append(img) 184 | 185 | for i, path in enumerate(path_list2): 186 | if i == 0: 187 | print(path) 188 | img = cv2.imread(path) 189 | img_list2.append(img) 190 | 191 | mssim_score(np.array(img_list1), np.array(img_list2)) 192 | 193 | 194 | 195 | 196 | 197 | 198 | -------------------------------------------------------------------------------- /IMLib/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangqianhui/Swapping-Autoencoder-tf/ae03c3af62842bbde90dac1b011aa9c90da794af/IMLib/__init__.py -------------------------------------------------------------------------------- /IMLib/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os 6 | import imageio 7 | import scipy.misc as misc 8 | import scipy 9 | import numpy as np 10 | import cv2 11 | 12 | def 
save_as_gif(images_list, out_path, gif_file_name='all', save_image=False): 13 | 14 | if os.path.exists(out_path) == False: 15 | os.mkdir(out_path) 16 | # save as .png 17 | if save_image == True: 18 | for n in range(len(images_list)): 19 | file_name = '{}.png'.format(n) 20 | save_path_and_name = os.path.join(out_path, file_name) 21 | misc.imsave(save_path_and_name, images_list[n]) 22 | # save as .gif 23 | out_path_and_name = os.path.join(out_path, '{}.gif'.format(gif_file_name)) 24 | imageio.mimsave(out_path_and_name, images_list, 'GIF', duration=0.1) 25 | 26 | def get_image(image_path, crop_size=128, is_crop=False, resize_w=140, is_grayscale=False): 27 | return transform(imread(image_path , is_grayscale), crop_size, is_crop, resize_w) 28 | 29 | def transform(image, crop_size=64, is_crop=True, resize_w=140): 30 | 31 | image = scipy.misc.imresize(image, [resize_w, resize_w]) 32 | if is_crop: 33 | cropped_image = center_crop(image, crop_size) 34 | else: 35 | cropped_image = image 36 | cropped_image = scipy.misc.imresize(cropped_image , 37 | [resize_w , resize_w]) 38 | 39 | return np.array(cropped_image) / 127.5 - 1 40 | 41 | def center_crop(x, crop_h, crop_w=None): 42 | 43 | if crop_w is None: 44 | crop_w = crop_h 45 | h, w = x.shape[:2] 46 | j = int(round((h - crop_h)/2.)) 47 | i = int(round((w - crop_w)/2.)) 48 | 49 | rate = np.random.uniform(0, 1, size=1) 50 | if rate < 0.5: 51 | x = np.fliplr(x) 52 | 53 | return x[j:j+crop_h, i:i+crop_w] 54 | 55 | def transform_image(image): 56 | return (image + 1) * 127.5 57 | 58 | def save_images(images, image_path, is_verse=True): 59 | if is_verse: 60 | return imsave(inverse_transform(images), path=image_path) 61 | else: 62 | return imsave(images, path=image_path) 63 | 64 | def imsave(images, path): 65 | images = cv2.cvtColor(images.astype('uint8'), cv2.COLOR_RGB2BGR) 66 | return cv2.imwrite(path, images) 67 | 68 | def resizeImg(img, size=list): 69 | return scipy.misc.imresize(img, size) 70 | 71 | def imread(path, is_grayscale=False): 72 | 73 | if (is_grayscale): 74 | return scipy.misc.imread(path, flatten=True).astype(np.float) 75 | else: 76 | return scipy.misc.imread(path).astype(np.float) 77 | 78 | def merge(images, size): 79 | if len(images.shape) == 3: 80 | img = images 81 | else: 82 | h, w = images.shape[1], images.shape[2] 83 | img = np.zeros((int(h * size[0]), int(w * size[1]), 3)) 84 | for idx, image in enumerate(images): 85 | i = idx % size[1] 86 | j = idx // size[1] 87 | img[j * h:j * h + h, i * w: i * w + w, :] = image 88 | return img 89 | 90 | def inverse_transform(image): 91 | result = ((image + 1) * 127.5).astype(np.uint8) 92 | result = np.clip(result, 0, 255) 93 | return result 94 | 95 | height_to_eyeball_radius_ratio = 1.1 96 | eyeball_radius_to_iris_diameter_ratio = 1.0 97 | 98 | def from_gaze2d(gaze, output_size, scale=1.0): 99 | 100 | """Generate a normalized pictorial representation of 3D gaze direction.""" 101 | gazemaps = [] 102 | oh, ow = np.round(scale * np.asarray(output_size)).astype(np.int32) 103 | oh_2 = int(np.round(0.5 * oh)) 104 | ow_2 = int(np.round(0.5 * ow)) 105 | r = int(height_to_eyeball_radius_ratio * oh_2) 106 | theta, phi = gaze 107 | theta = -theta 108 | sin_theta = np.sin(theta) 109 | cos_theta = np.cos(theta) 110 | sin_phi = np.sin(phi) 111 | cos_phi = np.cos(phi) 112 | 113 | # Draw iris 114 | eyeball_radius = int(height_to_eyeball_radius_ratio * oh_2) 115 | iris_radius_angle = np.arcsin(0.5 * eyeball_radius_to_iris_diameter_ratio) 116 | iris_radius = eyeball_radius_to_iris_diameter_ratio * eyeball_radius 117 
| iris_distance = float(eyeball_radius) * np.cos(iris_radius_angle) 118 | iris_offset = np.asarray([ 119 | -iris_distance * sin_phi * cos_theta, 120 | iris_distance * sin_theta, 121 | ]) 122 | iris_centre = np.asarray([ow_2, oh_2]) + iris_offset 123 | angle = np.degrees(np.arctan2(iris_offset[1], iris_offset[0])) 124 | ellipse_max = eyeball_radius_to_iris_diameter_ratio * iris_radius 125 | ellipse_min = np.abs(ellipse_max * cos_phi * cos_theta) 126 | #gazemap = np.zeros((oh, ow), dtype=np.float32) 127 | 128 | # Draw eyeball 129 | gazemap = np.zeros((oh, ow), dtype=np.float32) 130 | gazemap = cv2.ellipse(gazemap, box=(iris_centre, (ellipse_min, ellipse_max), angle), 131 | color = 1.0 , thickness=-1, lineType=cv2.LINE_AA) 132 | #outout = cv2.circle(test_gazemap, (ow_2, oh_2), r, color=1, thickness=-1) 133 | gazemaps.append(gazemap) 134 | 135 | gazemap = np.zeros((oh, ow), dtype=np.float32) 136 | gazemap = cv2.circle(gazemap, (ow_2, oh_2), r, color=1, thickness=-1) 137 | gazemaps.append(gazemap) 138 | 139 | return np.asarray(gazemaps) 140 | 141 | 142 | if __name__ == "__main__": 143 | 144 | target_angles = [[0.0, -1.0], [1.0, -1.0], [1.0, -0.66], [1.0, -0.33], [1.0, 0.0], [1.0, 0.33], 145 | [1.0, 0.66], [1.0, 1.0], [0, 1.0], [-1.0, 1.0]] 146 | 147 | for i, angles in enumerate(target_angles): 148 | x, y = angles[0], - 1 * angles[1] 149 | gazemaps = from_gaze2d((x, y), (64, 64)) 150 | cv2.imwrite("gazemaps_{}.jpg".format(i), gazemaps[1,...] * 255.0) -------------------------------------------------------------------------------- /PyLib/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangqianhui/Swapping-Autoencoder-tf/ae03c3af62842bbde90dac1b011aa9c90da794af/PyLib/__init__.py -------------------------------------------------------------------------------- /PyLib/const.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # python 3.x 3 | # Filename:const.py 4 | 5 | class _const: 6 | class ConstError(TypeError): pass 7 | def __setattr__(self, name, value): 8 | if name in self.__dict__: 9 | raise self.ConstError("Can't rebind const (%s)" % name) 10 | self.__dict__[name] = value 11 | import sys 12 | sys.modules[__name__] = _const() 13 | -------------------------------------------------------------------------------- /PyLib/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os, errno 6 | import logging 7 | 8 | def list_difference(a, b): 9 | ''' 10 | :param a: 11 | :param b: 12 | :return: 13 | ''' 14 | set_a = set(a) 15 | set_b = set(b) 16 | comparison = set_a.difference(set_b) 17 | 18 | return list(comparison) 19 | 20 | def mkdir_p(path): 21 | ''' 22 | :param path: 23 | :return: 24 | ''' 25 | try: 26 | os.makedirs(path) 27 | except OSError as exc: 28 | if exc.errno == errno.EEXIST and os.path.isdir(path): 29 | pass 30 | else: 31 | raise 32 | 33 | def makefolders(subfolders): 34 | ''' 35 | create multiple folders 36 | :param subfolders: 37 | :return: 38 | ''' 39 | assert isinstance(subfolders, list) 40 | 41 | for path in subfolders: 42 | if not os.path.exists(path): 43 | mkdir_p(path) 44 | 45 | def setLogConfig(): 46 | logging.basicConfig(level=logging.INFO) 47 | logger = logging.getLogger(__name__) 48 | return logger 49 | 
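A minimal usage sketch (not part of the repository; the constant name `IMG_SIZE` is hypothetical) of the write-once pattern implemented in PyLib/const.py above: importing the module yields a `_const` instance, so an attribute can be bound once, and any rebinding raises `ConstError` (a `TypeError` subclass).

```python
import PyLib.const as const   # the sys.modules entry is a _const instance

const.IMG_SIZE = 256          # first binding succeeds
try:
    const.IMG_SIZE = 512      # rebinding raises ConstError
except TypeError as err:
    print(err)                # "Can't rebind const (IMG_SIZE)"
```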
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Swapping-Autoencoder-tf 2 | An unofficial TensorFlow implementation of Swapping Autoencoder for Deep Image Manipulation. Paper: [Swapping Autoencoder](https://arxiv.org/abs/2007.00653) 3 | 4 | ![](img/model.png) 5 | 6 | ## Differences 7 | 8 | This implementation has two main differences from the original paper: 9 | 10 | - trained on 256x256 images, not 512x512 11 | 12 | - uses AdaIN, not the modulation/demodulation layer. We will update this in the next few days. 13 | 14 | ## Dependencies 15 | 16 | ```bash 17 | Python=3.6 18 | tensorflow=1.14 19 | pip install -r requirements.txt 20 | 21 | ``` 22 | Or using conda: 23 | 24 | ```bash 25 | conda create -n SAE python=3.6 26 | conda install tensorflow-gpu=1.14  # or higher 27 | ``` 28 | Install the other packages with pip. 29 | 30 | ## Usage 31 | 32 | - Clone this repo: 33 | ```bash 34 | git clone https://github.com/zhangqianhui/Swapping-Autoencoder-tf 35 | cd Swapping-Autoencoder-tf 36 | 37 | ``` 38 | 39 | - Download the CelebAHQ dataset 40 | 41 | Download the CelebAHQ dataset from the [Google Drive link](https://github.com/switchablenorms/CelebAMask-HQ). 42 | 43 | - Train the model from the command line: 44 | 45 | ```bash 46 | python train.py --gpu_id=0 --exper_name='log10_10_1' --data_dir='../dataset/CelebAMask-HQ/CelebA-HQ-img/' 47 | ``` 48 | - Test the model: 49 | 50 | ```bash 51 | python test.py --gpu_id=0 --exper_name='log10_10_1' --data_dir='../dataset/CelebAMask-HQ/CelebA-HQ-img/' 52 | ``` 53 | 54 | Or use the training script: 55 | 56 | ```bash 57 | bash scripts/train_log10_10_1.sh 58 | ``` 59 | 60 | For testing: 61 | 62 | ```bash 63 | bash scripts/test_log10_10_1.sh 64 | ``` 65 | 66 | ## Experiment Results after 50,000 iterations 67 | 68 | Training results on CelebAHQ. The 1st-4th columns are the structure input, texture input, reconstruction, and swapped result. 69 | 70 | ![](img/train.jpg) 71 | 72 | Testing results on CelebAHQ. The 1st-4th columns are the structure input, texture input, reconstruction, and swapped result.
73 | 74 | ![](img/test.jpg) 75 | 76 | -------------------------------------------------------------------------------- /SwapAutoEncoderAdaIN.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | from Dataset import save_images 4 | from tfLib.ops import * 5 | from tfLib.loss import * 6 | from tfLib.advloss import * 7 | import os 8 | import functools 9 | 10 | class SAE(object): 11 | 12 | # build model 13 | def __init__(self, dataset, opt): 14 | 15 | self.dataset = dataset 16 | self.opt = opt 17 | # placeholders 18 | self.x = tf.placeholder(tf.float32, 19 | [self.opt.batch_size, self.opt.img_size, self.opt.img_size, self.opt.input_nc]) 20 | # normalized crop boxes for the co-occurrence patch discriminator 21 | self.pfake = tf.placeholder(tf.float32, [self.opt.batch_size // 2 * self.opt.crop_n, 4]) 22 | self.preal = tf.placeholder(tf.float32, [self.opt.batch_size // 2 * pow(self.opt.crop_n, 2), 4]) 23 | self.preal2 = tf.placeholder(tf.float32, [self.opt.batch_size // 2 * self.opt.crop_n, 4]) 24 | 25 | self.lr_decay = tf.placeholder(tf.float32, None, name='lr_decay') 26 | self.noise_strength = tf.placeholder(tf.float32, [], name='noise') 27 | 28 | def build_model(self): 29 | 30 | self.x_list = tf.split(self.x, num_or_size_splits=2, axis=0) 31 | self.x1 = self.x_list[0] 32 | self.y = self.x_list[1] 33 | self.sx, self.tx = self.Encoder(self.x1) 34 | self.sy, self.ty = self.Encoder(self.y) 35 | self._x = self.G(self.sx, self.tx) 36 | self._xy = self.G(self.sx, self.ty) 37 | 38 | self.g_logits = self.D(tf.concat([self._x, self._xy], axis=0)) 39 | self.d_logits = self.D(self.x) 40 | 41 | # reconstruction loss 42 | self.recon_loss = L1(self._x, self.x1) 43 | 44 | d_loss_fun, g_loss_fun = get_adversarial_loss(self.opt.loss_type) 45 | 46 | self.d_gan_loss = d_loss_fun(self.d_logits, self.g_logits) 47 | self.g_gan_loss = g_loss_fun(self.g_logits) 48 | self.gp_x_loss, self.logits_x = self.gradient_penalty_just_real(self.x) 49 | 50 | # swapping loss 51 | self.y_local_list = self.croplocal(self.y, self.preal, num_or_size_splits=pow(self.opt.crop_n, 2)) 52 | self._xy_local_list = self.croplocal(self._xy, self.pfake, num_or_size_splits=self.opt.crop_n) 53 | self.y2_local_list = self.croplocal(self.y, self.preal2, num_or_size_splits=self.opt.crop_n) 54 | 55 | self.co_fake_logits = self.Co_D(self._xy_local_list, self.y_local_list) 56 | self.co_real_logits = self.Co_D(self.y2_local_list, self.y_local_list) 57 | self.co_gan_loss = d_loss_fun(self.co_real_logits, self.co_fake_logits) 58 | self.g_co_gan_loss = g_loss_fun(self.co_fake_logits) 59 | self.gp_co_loss = self.gradient_penalty_just_real(self.y2_local_list, self.y_local_list, is_d=False) 60 | 61 | self.co_gan_loss = self.co_gan_loss 62 | self.g_co_gan_loss = self.g_co_gan_loss 63 | 64 | self.D_loss = self.d_gan_loss + self.opt.lam_gp_d * self.gp_x_loss 65 | self.G_loss = self.g_gan_loss + self.recon_loss + self.g_co_gan_loss 66 | self.Co_loss = self.co_gan_loss + self.opt.lam_gp_co * self.gp_co_loss 67 | 68 | def croplocal(self, x, p, num_or_size_splits=8): 69 | 70 | preal_list = tf.split(p, num_or_size_splits=num_or_size_splits, axis=0) 71 | x_local_list = [] 72 | for i in range(len(preal_list)): 73 | y_local = self.crop_resize(x, tf.cast(preal_list[i], dtype=tf.float32)) 74 | x_local_list.append(y_local) 75 | x_local_list = tf.stack(x_local_list, 1) 76 | _, _, h, w, c = x_local_list.get_shape().as_list() 77 | x_local_list = tf.reshape(x_local_list, shape=[-1, h, w,
c]) 78 | 79 | return x_local_list 80 | 81 | def gradient_penalty_just_real(self, x, y=None, is_d=True): 82 | 83 | if is_d: 84 | discri_logits = self.D(x) 85 | gradients = tf.gradients(tf.reduce_sum(discri_logits), [x])[0] 86 | slopes = tf.reduce_sum(tf.square(gradients), [1, 2, 3]) 87 | return 0.5 * tf.reduce_mean(slopes), tf.reduce_sum(discri_logits) 88 | else: 89 | discri_logits = self.Co_D(x, y) 90 | gradients = tf.gradients(tf.reduce_sum(discri_logits), [x])[0] 91 | slopes = tf.reduce_sum(tf.square(gradients), [1, 2, 3]) 92 | return 0.5 * tf.reduce_mean(slopes) 93 | 94 | def build_test_model(self): 95 | 96 | self.x_list = tf.split(self.x, num_or_size_splits=2, axis=0) 97 | self.x1 = self.x_list[0] 98 | self.y = self.x_list[1] 99 | self.sx, self.tx = self.Encoder(self.x1) 100 | self.sy, self.ty = self.Encoder(self.y) 101 | self._x = self.G(self.sx, self.tx) 102 | self._xy = self.G(self.sx, self.ty) 103 | 104 | def crop_resize(self, input, boxes): 105 | shape = [int(item) for item in input.shape.as_list()] 106 | return tf.image.crop_and_resize(input, boxes=boxes, box_ind=list(range(0, int(shape[0]))), 107 | crop_size=[int(shape[1] / 4), int(shape[2] / 4)]) 108 | 109 | def train(self): 110 | 111 | self.t_vars = tf.trainable_variables() 112 | 113 | self.d_vars = getTrainVariable(vars=self.t_vars, scope='Discriminator') 114 | self.g_vars = getTrainVariable(vars=self.t_vars, scope='Generator') 115 | self.en_vars = getTrainVariable(vars=self.t_vars, scope='Encoder') 116 | self.co_vars = getTrainVariable(vars=self.t_vars, scope='Co-occurrence') 117 | 118 | assert len(self.t_vars) == len(self.d_vars + self.g_vars + self.en_vars + self.co_vars) 119 | 120 | self.saver = tf.train.Saver() 121 | 122 | opti_D = tf.train.AdamOptimizer(self.opt.lr_d, beta1=self.opt.beta1, beta2=self.opt.beta2). \ 123 | minimize(loss=self.D_loss, var_list=self.d_vars) 124 | opti_G = tf.train.AdamOptimizer(self.opt.lr_g, beta1=self.opt.beta1, beta2=self.opt.beta2). 
\ 125 | minimize(loss=self.G_loss, var_list=self.g_vars + self.en_vars) 126 | opti_Co = tf.train.AdamOptimizer(self.opt.lr_co, beta1=self.opt.beta1, beta2=self.opt.beta2).minimize( 127 | loss=self.Co_loss, var_list=self.co_vars) 128 | 129 | init = tf.global_variables_initializer() 130 | config = tf.ConfigProto() 131 | config.gpu_options.allow_growth = True 132 | 133 | with tf.Session(config=config) as sess: 134 | 135 | sess.run(init) 136 | ckpt = tf.train.get_checkpoint_state(self.opt.checkpoints_dir) 137 | if ckpt and ckpt.model_checkpoint_path: 138 | start_step = int(ckpt.model_checkpoint_path.split('model_', 2)[1].split('.', 2)[0]) 139 | self.saver.restore(sess, ckpt.model_checkpoint_path) 140 | else: 141 | start_step = 0 142 | 143 | step = start_step 144 | print("Start reading the dataset") 145 | 146 | tr_img, te_img = self.dataset.input() 147 | coord = tf.train.Coordinator() 148 | threads = tf.train.start_queue_runners(sess=sess, coord=coord) 149 | 150 | _te_img = sess.run(te_img) 151 | print("Entering the training loop") 152 | while step <= self.opt.niter: 153 | 154 | if step < self.opt.niter: 155 | lr_decay = (self.opt.niter - step) / self.opt.niter 156 | else: 157 | lr_decay = 0 158 | 159 | _tr_img = sess.run(tr_img) 160 | preal = self.get_pos(self.opt.batch_size // 2 * self.opt.crop_n * self.opt.crop_n) 161 | pfake = self.get_pos(self.opt.batch_size // 2 * self.opt.crop_n) 162 | preal2 = self.get_pos(self.opt.batch_size // 2 * self.opt.crop_n) 163 | # noise_strength = self.opt.initial_noise_factor * \ 164 | # max(0.0, 1.0 - (step / self.opt.niter) / self.opt.noise_ramp_length) ** 2 165 | 166 | f_d = {self.x: _tr_img, 167 | self.pfake: pfake, 168 | self.preal: preal, 169 | self.preal2: preal2, 170 | self.lr_decay: lr_decay, 171 | self.noise_strength: 0.0} 172 | 173 | # alternately optimize Co-D, D, and G 174 | sess.run(opti_Co, feed_dict=f_d) 175 | sess.run(opti_D, feed_dict=f_d) 176 | sess.run(opti_G, feed_dict=f_d) 177 | 178 | if step % 500 == 0: 179 | 180 | o_loss = sess.run([self.D_loss, self.G_loss, self.Co_loss, self.d_gan_loss, 181 | self.g_gan_loss, self.co_gan_loss, self.gp_x_loss, self.gp_co_loss, self.recon_loss], feed_dict=f_d) 182 | print("step %d d_loss=%.4f, g_loss=%.4f, co_loss=%.4f, d_gan_loss=%.4f, " 183 | "g_gan_loss=%.4f, co_gan_loss=%.4f, gp_x_loss=%.4f, gp_co_loss=%.4f, recon_loss=%.4f, lr_decay=%.4f" % (step, 184 | o_loss[0], o_loss[1], o_loss[2], o_loss[3], o_loss[4], o_loss[5], o_loss[6], o_loss[7], o_loss[8], lr_decay)) 185 | 186 | if np.mod(step, 500) == 0: 187 | 188 | tr_o = sess.run([self.x1, self.y, self._x, self._xy, self._xy_local_list, self.y2_local_list], feed_dict=f_d) 189 | _tr_o = self.Transpose(np.array([tr_o[0], tr_o[1], tr_o[2], tr_o[3]])) 190 | 191 | f_d = {self.x: _te_img, self.lr_decay: lr_decay, self.noise_strength: 0} 192 | te_o = sess.run([self.x1, self.y, self._x, self._xy], feed_dict=f_d) 193 | _te_o = self.Transpose(np.array([te_o[0], te_o[1], te_o[2], te_o[3]])) 194 | _local_o = self.Transpose(np.array([tr_o[4], tr_o[5]])) 195 | 196 | save_images(_tr_o, '{}/{:02d}_tr.jpg'.format(self.opt.sample_dir, step)) 197 | save_images(_te_o, '{}/{:02d}_te.jpg'.format(self.opt.sample_dir, step)) 198 | save_images(_local_o, '{}/{:02d}_te_local.jpg'.format(self.opt.sample_dir, step)) 199 | 200 | if np.mod(step, self.opt.save_model_freq) == 0 and step != 0: 201 | self.saver.save(sess, os.path.join(self.opt.checkpoints_dir, 'model_{:06d}.ckpt'.format(step))) 202 | step += 1 203 | 204 | save_path = self.saver.save(sess, os.path.join(self.opt.checkpoints_dir,
'model_{:06d}.ckpt'.format(step))) 205 | coord.request_stop() 206 | coord.join(threads) 207 | 208 | print("Model saved in file: %s" % save_path) 209 | 210 | def test(self): 211 | 212 | init = tf.global_variables_initializer() 213 | config = tf.ConfigProto() 214 | config.gpu_options.allow_growth = True 215 | 216 | with tf.Session(config=config) as sess: 217 | 218 | sess.run(init) 219 | self.saver = tf.train.Saver() 220 | ckpt = tf.train.get_checkpoint_state(self.opt.checkpoints_dir) 221 | print('Load checkpoint', ckpt) 222 | if ckpt and ckpt.model_checkpoint_path: 223 | self.saver.restore(sess, ckpt.model_checkpoint_path) 224 | print('Load succeeded!') 225 | else: 226 | print('No checkpoint exists, load failed!') 227 | exit() 228 | 229 | _, test_batch_image = self.dataset.input() 230 | coord = tf.train.Coordinator() 231 | threads = tf.train.start_queue_runners(sess=sess, coord=coord) 232 | 233 | batch_num = self.opt.test_num // self.opt.batch_size 234 | for i in range(batch_num): 235 | f_te_img = sess.run(test_batch_image) 236 | f_d = {self.x: f_te_img} 237 | output = sess.run([self.x1, self.y, self._x, self._xy], feed_dict=f_d) 238 | _te_o = self.Transpose(np.array([output[0], output[1], output[2], output[3]])) 239 | save_images(_te_o, '{}/{:02d}_o.jpg'.format(self.opt.test_sample_dir, i)) 240 | 241 | coord.request_stop() 242 | coord.join(threads) 243 | 244 | def Transpose(self, tensor_list): 245 | refined_list = np.transpose(np.array(tensor_list), axes=[1, 2, 0, 3, 4]) 246 | refined_list = np.reshape(refined_list, [refined_list.shape[0] * refined_list.shape[1], 247 | refined_list.shape[2] * refined_list.shape[3], -1]) 248 | return refined_list 249 | 250 | def D(self, x): 251 | 252 | n_layers_d = self.opt.n_layers_d 253 | ndf = self.opt.ndf 254 | conv2d_first = functools.partial(conv2d, k=1, s=1, output_dim=ndf) 255 | conv2d_middle = functools.partial(conv2d, k=3, s=1, padding='VALID') 256 | ful_final1 = functools.partial(fully_connect) 257 | ful_final2 = functools.partial(fully_connect, output_dim=1) 258 | ResBlock_ = functools.partial(Resblock, relu_type='lrelu', padding='SAME', ds=True, use_IN=False) 259 | with tf.variable_scope("Discriminator", reuse=tf.AUTO_REUSE): 260 | x = conv2d_first(x, output_dim=ndf, scope='conv_first') 261 | for i in range(n_layers_d): 262 | c_dim = np.minimum(self.opt.ndf * np.power(2, i + 1), 256) 263 | x = ResBlock_(x, o_dim=c_dim, scope='r_en{}'.format(i)) 264 | x = lrelu(conv2d_middle(lrelu(x), output_dim=c_dim, scope='conv_middle')) 265 | x = tf.reshape(x, [-1, np.prod([d.value for d in x.shape[1:]])]) 266 | x = lrelu(ful_final1(x, output_dim=c_dim, scope='ful_final1')) 267 | x = ful_final2(x, scope='ful_final2') 268 | 269 | return x 270 | 271 | def Co_D(self, x_local, y_local_list): 272 | 273 | ful_s = functools.partial(fully_connect, output_dim=2048) 274 | ful_m = functools.partial(fully_connect, output_dim=1024, scope='ful_m') 275 | ful_f = functools.partial(fully_connect, output_dim=1, scope='ful_f') 276 | with tf.variable_scope("Co-occurrence", reuse=tf.AUTO_REUSE): 277 | x_local_fp = self.Patch_Encoder(x_local) 278 | y_local_fp = self.Patch_Encoder(y_local_list) 279 | _, h, w, c = y_local_fp.get_shape().as_list() 280 | y_local_fp = tf.reshape(y_local_fp, shape=[-1, self.opt.crop_n, h, w, c]) 281 | y_local_fp = tf.reduce_mean(y_local_fp, axis=1) 282 | fp = tf.reshape(tf.concat([x_local_fp, y_local_fp], axis=-1), 283 | [-1, x_local_fp.shape[-1] + y_local_fp.shape[-1]]) 284 | 285 | fp = lrelu(ful_s(fp, scope='ful_s1')) 286 | fp = lrelu(ful_s(fp,
scope='ful_s2')) 287 | fp = lrelu(ful_m(fp)) 288 | logits = ful_f(fp) 289 | 290 | return logits 291 | 292 | def Patch_Encoder(self, x): 293 | 294 | n_layers_co_d = self.opt.n_layers_co_d 295 | ncodf = self.opt.ncodf 296 | conv2d_first = functools.partial(conv2d, k=3, s=1, output_dim=ncodf) 297 | conv2d_middle = functools.partial(conv2d, k=3, s=1, padding='VALID') 298 | ResBlockDs = functools.partial(Resblock, relu_type='lrelu', padding='SAME', ds=True, use_IN=False) 299 | ResBlock = functools.partial(Resblock, relu_type='lrelu', padding='SAME', ds=False, use_IN=False) 300 | with tf.variable_scope("Co-occurrence", reuse=tf.AUTO_REUSE): 301 | x = conv2d_first(x, output_dim=ncodf, scope='conv_first') 302 | for i in range(n_layers_co_d): 303 | c_dim = [64, 128, 256, 384] 304 | x = ResBlockDs(x, o_dim=c_dim[i], scope='rds{}'.format(i)) 305 | x = lrelu(ResBlock(x, o_dim=c_dim[-1] * 2, scope='r_1')) 306 | x = lrelu(conv2d_middle(x, output_dim=c_dim[-1], scope='conv_middle')) 307 | 308 | return x 309 | 310 | def Encoder(self, x_init): 311 | 312 | nef = self.opt.nef 313 | n_layers_e = self.opt.n_layers_e 314 | conv2d_first = functools.partial(conv2d, k=1, s=1, output_dim=nef) 315 | conv2d_final = functools.partial(conv2d, k=1, s=1) 316 | ful = functools.partial(fully_connect, output_dim=512) 317 | with tf.variable_scope("Encoder", reuse=tf.AUTO_REUSE): 318 | 319 | x = x_init 320 | x = conv2d_first(x, scope='conv') 321 | for i in range(n_layers_e): 322 | c_dim = np.minimum(self.opt.nef * np.power(2, i + 1), 256) 323 | x = Resblock(x, o_dim=c_dim, use_IN=False, scope='r_en{}'.format(i)) 324 | 325 | #stru 326 | s = lrelu(conv2d(lrelu(x), output_dim=256, k=1, s=1, scope='conv_s1')) 327 | s = conv2d_final(s, output_dim=8, scope='conv_s2') 328 | 329 | #texture 330 | t = lrelu(conv2d(x, output_dim=256, k=1, padding='VALID', scope='conv_t1')) 331 | t = lrelu(conv2d(t, output_dim=512, k=1, padding='VALID', scope='conv_t2')) 332 | t = avgpool2d(t, k=t.shape[-2]) 333 | t = ful(tf.squeeze(t, axis=[1,2]), scope='ful_t3') 334 | return s, t 335 | 336 | def G(self, structure, texture): 337 | 338 | n_layers_g = self.opt.n_layers_g 339 | conv2d_final = functools.partial(conv2d, k=1, s=1, padding='VALID', output_dim=self.opt.output_nc) 340 | RAA = functools.partial(Resblock_AdaIn_Affline_layers, style_code=texture) 341 | with tf.variable_scope("Generator", reuse=tf.AUTO_REUSE): 342 | 343 | s = structure 344 | for i in range(2): 345 | c_dim = 128 * (i+1) 346 | s = RAA(s, o_dim=c_dim, us=False, scope='AdaInAffline_{}'.format(i)) 347 | 348 | for i in range(n_layers_g): 349 | c_dim = [512, 512, 256, 128] 350 | s = RAA(s, o_dim=c_dim[i], scope='AdaInAfflineD_{}'.format(i)) 351 | 352 | s = tf.nn.tanh(conv2d_final(lrelu(s), scope='f')) 353 | return s 354 | 355 | def get_pos(self, batch_size): 356 | 357 | batch_pos = [] 358 | for i in range(batch_size): 359 | pos = [] 360 | rate = np.random.uniform(4, 8, size=1) 361 | 362 | wh = self.opt.img_size // rate 363 | x = np.random.randint(wh // 2, self.opt.img_size - wh //2) 364 | y = np.random.randint(wh // 2, self.opt.img_size - wh //2) 365 | 366 | center = (x, y) 367 | scale = center[1] - wh // 2 368 | down_scale = center[1] + wh // 2 369 | l1_1 = int(scale) 370 | u1_1 = int(down_scale) 371 | 372 | scale = center[0] - wh // 2 373 | down_scale = center[0] + wh // 2 374 | l1_2 = int(scale) 375 | u1_2 = int(down_scale) 376 | 377 | pos.append(float(l1_1) / self.opt.img_size) 378 | pos.append(float(l1_2) / self.opt.img_size) 379 | pos.append(float(u1_1) / self.opt.img_size) 380 | 
pos.append(float(u1_2) / self.opt.img_size) 381 | batch_pos.append(pos) 382 | 383 | return np.array(batch_pos) 384 | 385 | -------------------------------------------------------------------------------- /config/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangqianhui/Swapping-Autoencoder-tf/ae03c3af62842bbde90dac1b011aa9c90da794af/config/__init__.py -------------------------------------------------------------------------------- /config/options.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from PyLib.utils import makefolders 4 | from abc import abstractmethod 5 | 6 | class BaseOptions(): 7 | def __init__(self): 8 | self.initialized = False 9 | def initialize(self, parser): 10 | 11 | default_chosen_att_names = ['Male'] 12 | parser.add_argument('--chosen_att_names', nargs='+', default=default_chosen_att_names) 13 | parser.add_argument('--data_dir', type=str, 14 | default='../../dataset/CelebAMask-HQ/CelebA-HQ-img/', help='path to images') 15 | parser.add_argument('--vgg_path', type=str, 16 | default='./vgg_16.ckpt', help='vgg path for perceptual loss') 17 | parser.add_argument('--inception_path', type=str, default='../pretrained/') 18 | parser.add_argument('--label_dir', type=str, 19 | default='../../dataset/CelebAMask-HQ/CelebAMask-HQ-attribute-anno.txt', help='path to images') 20 | parser.add_argument('--gpu_id', type=str, default='3', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU') 21 | parser.add_argument('--img_size', type=int, default=256, help='scale images to this size') 22 | parser.add_argument('--input_nc', type=int, default=3, help='# of input image channels') 23 | parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels') 24 | parser.add_argument('--nef', type=int, default=32, help='# of style encoder filters in first layer') 25 | parser.add_argument('--ncodf', type=int, default=16, help='# of co d filters in first fully layers') 26 | parser.add_argument('--ngf', type=int, default=64, help='# of generator filters in first conv layer') 27 | parser.add_argument('--ndf', type=int, default=16, help='# of discriminator filters in first conv layer') 28 | parser.add_argument('--n_layers_e', type=int, default=4, help='layers of texture encoder') 29 | parser.add_argument('--n_layers_g', type=int, default=4, help='layers of generator') 30 | parser.add_argument('--n_layers_d', type=int, default=6, help='layers of d model') 31 | parser.add_argument('--n_layers_co_d', type=int, default=4, help='layers of co-d') 32 | parser.add_argument('--n_blocks', type=int, default=4, help='layers of residual block') 33 | parser.add_argument('--n_latent', type=int, default=16, help='the dim for latent code z') 34 | parser.add_argument('--exper_name', type=str, default='log7_6', help='name of the experiment. 
It decides where to store samples and models') 35 | parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here') 36 | parser.add_argument('--log_dir', type=str, default='./logs', help='logs for tensorboard') 37 | parser.add_argument('--capacity', type=int, default=500, help='capacity for queue in training') 38 | parser.add_argument('--num_threads', type=int, default=5, help='threads for reading data in training') 39 | parser.add_argument('--sample_dir', type=str, default='./sample_dir', help='dir for sample images') 40 | parser.add_argument('--test_sample_dir', type=str, default='test_sample_dir', help='test sample images are saved here') 41 | parser.add_argument('--resize_or_crop', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop|crop|scale_width|scale_width_and_crop]') 42 | parser.add_argument('--initial_noise_factor', type=float, default=0.05, help='initial noise factor') 43 | parser.add_argument('--noise_ramp_length', type=float, default=0.75, help='noise ramp length') 44 | 45 | self.initialized = True 46 | return parser 47 | 48 | def gather_options(self): 49 | # initialize parser with basic options 50 | if not self.initialized: 51 | parser = argparse.ArgumentParser( 52 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 53 | parser = self.initialize(parser) 54 | # get the basic options 55 | opt, _ = parser.parse_known_args() 56 | self.parser = parser 57 | return parser.parse_args() 58 | 59 | def print_options(self, opt): 60 | 61 | opt.checkpoints_dir = os.path.join(opt.exper_name, opt.checkpoints_dir) 62 | opt.sample_dir = os.path.join(opt.exper_name, opt.sample_dir) 63 | 64 | opt.test_sample_dir = os.path.join(opt.exper_name, opt.test_sample_dir) 65 | opt.test_sample_dir0 = os.path.join(opt.test_sample_dir, '0') 66 | opt.test_sample_dir1 = os.path.join(opt.test_sample_dir, '1') 67 | 68 | opt.log_dir = os.path.join(opt.exper_name, opt.log_dir) 69 | makefolders([opt.inception_path, opt.checkpoints_dir, 70 | opt.sample_dir, opt.test_sample_dir, opt.log_dir, opt.test_sample_dir0, opt.test_sample_dir1]) 71 | 72 | message = '' 73 | message += '----------------- Options ---------------\n' 74 | for k, v in sorted(vars(opt).items()): 75 | comment = '' 76 | default = self.parser.get_default(k) 77 | if v != default: 78 | comment = '\t[default: %s]' % str(default) 79 | message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment) 80 | message += '----------------- End -------------------' 81 | 82 | # save to the disk 83 | if opt.isTrain: 84 | file_name = os.path.join(opt.checkpoints_dir, 'opt.txt') 85 | else: 86 | file_name = os.path.join(opt.checkpoints_dir, 'test_opt.txt') 87 | with open(file_name, 'wt') as opt_file: 88 | opt_file.write(message) 89 | opt_file.write('\n') 90 | 91 | @abstractmethod 92 | def parse(self): 93 | pass 94 | -------------------------------------------------------------------------------- /config/test_options.py: -------------------------------------------------------------------------------- 1 | from .options import BaseOptions 2 | 3 | class TestOptions(BaseOptions): 4 | 5 | def initialize(self, parser): 6 | 7 | parser = BaseOptions.initialize(self, parser) 8 | parser.add_argument('--batch_size', type=int, default=8, help='input batch size') 9 | parser.add_argument('--pos_number', type=int, default=4, help='position') 10 | parser.add_argument('--use_sp', action='store_true', help='use spectral normalization') 11 | parser.add_argument('--test_num', type=int,
default=300, help='the number of test samples') 12 | parser.add_argument('--n_att', type=float, default=4, help='number of attributes') 13 | parser.add_argument('--crop_n', type=int, default=8, help='number of crops') 14 | 15 | self.isTrain = False 16 | return parser 17 | 18 | def parse(self): 19 | 20 | opt = self.gather_options() 21 | opt.isTrain = self.isTrain 22 | self.print_options(opt) 23 | self.opt = opt 24 | 25 | return self.opt -------------------------------------------------------------------------------- /config/train_options.py: -------------------------------------------------------------------------------- 1 | from .options import BaseOptions 2 | 3 | class TrainOptions(BaseOptions): 4 | 5 | def initialize(self, parser): 6 | 7 | parser = BaseOptions.initialize(self, parser) 8 | parser.add_argument('--display_freq', type=int, default=100, help='frequency of showing training results on screen') 9 | parser.add_argument('--save_latest_freq', type=int, default=1000, help='frequency of saving the latest results') 10 | parser.add_argument('--save_model_freq', type=int, default=10000, help='frequency of saving checkpoints') 11 | parser.add_argument('--batch_size', type=int, default=16, help='input batch size') 12 | parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model') 13 | parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc') 14 | parser.add_argument('--niter', type=int, default=100000, help='# of iter at starting learning rate') 15 | parser.add_argument('--niter_decay', type=int, default=50000, help='# of iter to linearly decay learning rate to zero') 16 | parser.add_argument('--lr_d', type=float, default=2e-3, help='initial learning rate for Adam in d') 17 | parser.add_argument('--lr_g', type=float, default=2e-3, help='initial learning rate for Adam in g') 18 | parser.add_argument('--lr_co', type=float, default=1e-4, help='initial learning rate for Adam in the co-occurrence discriminator') 19 | parser.add_argument('--beta1', type=float, default=0.0, help='beta1 for adam') 20 | parser.add_argument('--beta2', type=float, default=0.99, help='beta2 for adam') 21 | parser.add_argument('--loss_type', type=str, default='softplus', 22 | choices=['gan', 'hinge', 'wgan_gp', 'lsgan', 'softplus'], help='type of GAN loss to use') 23 | parser.add_argument('--loss_type2', type=str, default='lsgan', 24 | choices=['gan', 'hinge', 'wgan_gp', 'lsgan', 'softplus'], help='type of GAN loss to use') 25 | parser.add_argument('--gp_type', type=str, default='R1_regu', choices=['Dirac', 'wgan-gp', 'R1_regu'], help='gp type') 26 | parser.add_argument('--lam_gp_d', type=float, default=10.0, help='weight for gradient penalty of d') 27 | parser.add_argument('--lam_gp_co', type=float, default=1.0, help='weight for gradient penalty of co d') 28 | parser.add_argument('--pos_number', type=int, default=4, help='position') 29 | parser.add_argument('--test_num', type=int, default=300, help='the number of test samples') 30 | parser.add_argument('--crop_n', type=int, default=8, help='number of crops') 31 | parser.add_argument('--d_reg_every', type=int, default=1, help='l1 reg optimization every d') 32 | parser.add_argument('--g_reg_every', type=int, default=1, help='l1 reg optimization every g') 33 | 34 | self.isTrain = True 35 | return parser 36 | 37 | def parse(self): 38 | 39 | opt = self.gather_options() 40 | opt.isTrain = self.isTrain 41 | self.print_options(opt) 42 | self.opt = opt 43 | 44 | return self.opt
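The `--niter` option above drives the linear decay factor that `SwapAutoEncoderAdaIN.train` later computes and feeds through its `lr_decay` placeholder every step. A standalone sketch (not part of the repository) of that schedule:

```python
def linear_decay(step, niter=100000):
    # mirrors the training loop: decays from 1.0 at step 0 to 0.0 at step niter
    return (niter - step) / niter if step < niter else 0.0

assert linear_decay(0) == 1.0
assert linear_decay(50000) == 0.5
assert linear_decay(100000) == 0.0
```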
-------------------------------------------------------------------------------- /img/model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangqianhui/Swapping-Autoencoder-tf/ae03c3af62842bbde90dac1b011aa9c90da794af/img/model.png -------------------------------------------------------------------------------- /img/test.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangqianhui/Swapping-Autoencoder-tf/ae03c3af62842bbde90dac1b011aa9c90da794af/img/test.jpg -------------------------------------------------------------------------------- /img/train.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangqianhui/Swapping-Autoencoder-tf/ae03c3af62842bbde90dac1b011aa9c90da794af/img/train.jpg -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | opencv-python==4.4.0.44 2 | Pillow==7.2.0 3 | scipy==1.2.1 4 | setproctitle==1.1.10 5 | six==1.15.0 6 | tensorboard==1.14.0 7 | tensorflow==1.14.0 8 | numpy==1.19.1 9 | tensorflow-estimator==1.14.0 10 | -------------------------------------------------------------------------------- /scripts/test_log10_10_1.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python test.py --gpu_id=0 --exper_name='log10_10_1' -------------------------------------------------------------------------------- /scripts/train_log10_10_1.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python train.py --gpu_id=7 --exper_name='log10_10_1' -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os 6 | from Dataset import CelebA 7 | from SwapAutoEncoderAdaIN import SAE 8 | from config.test_options import TestOptions 9 | 10 | opt = TestOptions().parse() 11 | os.environ['CUDA_VISIBLE_DEVICES'] = str(opt.gpu_id) 12 | 13 | if __name__ == "__main__": 14 | 15 | dataset = CelebA(opt) 16 | gaze_gan = SAE(dataset, opt) 17 | gaze_gan.build_test_model() 18 | gaze_gan.test() 19 | -------------------------------------------------------------------------------- /tfLib/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangqianhui/Swapping-Autoencoder-tf/ae03c3af62842bbde90dac1b011aa9c90da794af/tfLib/__init__.py -------------------------------------------------------------------------------- /tfLib/advloss.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | import functools 5 | import tensorflow as tf 6 | 7 | def get_gan_losses_fn(): 8 | bce = functools.partial(tf.nn.sigmoid_cross_entropy_with_logits) 9 | def d_loss_fn(r_logit, f_logit): 10 | r_loss = bce(labels=tf.ones_like(r_logit), logits=r_logit) 11 | f_loss = bce(labels=tf.zeros_like(f_logit), logits=f_logit) 12 | return tf.reduce_mean(r_loss + f_loss) 13 | 14 | def g_loss_fn(f_logit): 15 | f_loss = 
bce(labels=tf.ones_like(f_logit), logits=f_logit) 16 | return tf.reduce_mean(f_loss) 17 | 18 | return d_loss_fn, g_loss_fn 19 | 20 | def get_hinge_loss(): 21 | def loss_hinge_dis(d_real_logits, d_fake_logits): 22 | loss = tf.reduce_mean(tf.nn.relu(1.0 - d_real_logits)) 23 | loss += tf.reduce_mean(tf.nn.relu(1.0 + d_fake_logits)) 24 | return loss 25 | 26 | def loss_hinge_gen(d_fake_logits): 27 | loss = - tf.reduce_mean(d_fake_logits) 28 | return loss 29 | 30 | return loss_hinge_dis, loss_hinge_gen 31 | 32 | def get_softplus_loss(): 33 | 34 | def loss_dis(d_real_logits, d_fake_logits): 35 | l1 = tf.reduce_mean(tf.nn.softplus(-d_real_logits)) 36 | l2 = tf.reduce_mean(tf.nn.softplus(d_fake_logits)) 37 | return l1 + l2 38 | 39 | def loss_gen(d_fake_logits): 40 | return tf.reduce_mean(tf.nn.softplus(-d_fake_logits)) 41 | 42 | return loss_dis, loss_gen 43 | 44 | def get_lsgan_loss(): 45 | 46 | def d_lsgan_loss(d_real_logits, d_fake_logits): 47 | return tf.reduce_mean((d_real_logits - 0.9)**2) \ 48 | + tf.reduce_mean(d_fake_logits**2) 49 | 50 | def g_lsgan_loss(d_fake_logits): 51 | return tf.reduce_mean((d_fake_logits - 0.9)**2) 52 | 53 | return d_lsgan_loss, g_lsgan_loss 54 | 55 | def get_wgan_losses_fn(): 56 | 57 | def d_loss_fn(r_logit, f_logit): 58 | r_loss = - tf.reduce_mean(r_logit) 59 | f_loss = tf.reduce_mean(f_logit) 60 | return r_loss + f_loss 61 | 62 | def g_loss_fn(f_logit): 63 | f_loss = - tf.reduce_mean(f_logit) 64 | return f_loss 65 | 66 | return d_loss_fn, g_loss_fn 67 | 68 | def get_adversarial_loss(mode): 69 | 70 | print("mode", mode) 71 | if mode == 'gan': 72 | return get_gan_losses_fn() 73 | elif mode == 'hinge': 74 | return get_hinge_loss() 75 | elif mode == 'lsgan': 76 | return get_lsgan_loss() 77 | elif mode == 'softplus': 78 | return get_softplus_loss() 79 | elif mode == 'wgan_gp': 80 | return get_wgan_losses_fn() 81 | -------------------------------------------------------------------------------- /tfLib/flowfield.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | def bilinear_sample(input, flow, name): 5 | # reference: spatial transformer network 6 | # 1. details can be found in the official release: 7 | # https://github.com/tensorflow/models/blob/master/research/transformer/spatial_transformer.py 8 | # 2. another good implementation can be found in: 9 | # https://github.com/kevinzakka/spatial-transformer-network/blob/master/transformer.py 10 | # but that one may contain some problems, see --> https://github.com/kevinzakka/spatial-transformer-network/issues/10 11 | with tf.variable_scope(name): 12 | N, iH, iW, iC = input.get_shape().as_list() 13 | _, fH, fW, fC = flow.get_shape().as_list() 14 | 15 | assert iH == fH and iW == fW 16 | # re-order & reshape: N,H,W,C --> N,C,H*W, e.g. shape (16, 2, 3500) 17 | flow = tf.reshape(tf.transpose(flow, [0, 3, 1, 2]), [-1, fC, fH * fW]) 18 | # get mesh-grid, 2,H*W 19 | indices_grid = meshgrid(iH, iW) 20 | transformed_grid = tf.add(flow, indices_grid) 21 | x_s = tf.slice(transformed_grid, [0, 0, 0], [-1, 1, -1]) # x_s should be (16,1,3500) 22 | y_s = tf.slice(transformed_grid, [0, 1, 0], [-1, 1, -1]) # y_s should be (16,1,3500) 23 | # see the tf.slice documentation for the exact slicing semantics 24 | x_s_flatten = tf.reshape(x_s, [-1]) # should be (16*3500) 25 | y_s_flatten = tf.reshape(y_s, [-1]) # should be (16*3500) 26 | transformed_image = interpolate(input, x_s_flatten, y_s_flatten, iH, iW, 'interpolate') 27 | # print(transformed_image.get_shape().as_list()) 28 | 
transformed_image = tf.reshape(transformed_image, [N, iH, iW, iC]) 29 | 30 | return transformed_image 31 | 32 | def meshgrid(height, width, ones_flag=None): 33 | 34 | with tf.variable_scope('meshgrid'): 35 | y_linspace = tf.linspace(-1., 1., height) 36 | x_linspace = tf.linspace(-1., 1., width) 37 | x_coordinates, y_coordinates = tf.meshgrid(x_linspace, y_linspace) 38 | x_coordinates = tf.reshape(x_coordinates, shape=[-1]) #[H*W] 39 | y_coordinates = tf.reshape(y_coordinates, shape=[-1]) #[H*W] 40 | if ones_flag is None: 41 | indices_grid = tf.stack([x_coordinates, y_coordinates], axis=0) #[2, H*W] 42 | else: 43 | indices_grid = tf.stack([x_coordinates, y_coordinates, tf.ones_like(x_coordinates)], axis=0) 44 | 45 | return indices_grid 46 | 47 | 48 | def interpolate(input, x, y, out_height, out_width, name): 49 | # parameters: input is the input image, with shape (batch_size, height, width, 3) 50 | # x, y are flattened coordinates, each of shape (16*3500) = 56000 51 | # out_height, out_width = height, width 52 | with tf.variable_scope(name): 53 | N, H, W, C = input.get_shape().as_list() #64, 40, 72, 3 54 | 55 | x = tf.cast(x, dtype=tf.float32) 56 | y = tf.cast(y, dtype=tf.float32) 57 | H_f = tf.cast(H, dtype=tf.float32) 58 | W_f = tf.cast(W, dtype=tf.float32) 59 | # note that x, y are in [-1,1] before this 60 | x = (x + 1.0) * (W_f - 1) * 0.5 # x is now [0,2]*0.5*(width-1), i.e. [0,1]*(width-1) 61 | # shape 16 * 3500 62 | y = (y + 1.0) * (H_f - 1) * 0.5 63 | # get x0 and x1 in bilinear interpolation 64 | x0 = tf.cast(tf.floor(x), tf.int32) # floor and cast to int 65 | x1 = x0 + 1 66 | y0 = tf.cast(tf.floor(y), tf.int32) 67 | y1 = y0 + 1 68 | 69 | # clip the coordinate value 70 | max_y = tf.cast(H - 1, dtype=tf.int32) 71 | max_x = tf.cast(W - 1, dtype=tf.int32) 72 | zero = tf.constant([0], shape=(1,), dtype=tf.int32) 73 | 74 | x0 = tf.clip_by_value(x0, zero, max_x) 75 | x1 = tf.clip_by_value(x1, zero, max_x) 76 | y0 = tf.clip_by_value(y0, zero, max_y) 77 | y1 = tf.clip_by_value(y1, zero, max_y) 78 | # note x0, x1, y0, y1 all have shape 16 * 3500; 79 | # tf.clip_by_value restricts the coordinates 80 | # to the valid pixel range 81 | flat_image_dimensions = H * W 82 | pixels_batch = tf.range(N) * flat_image_dimensions 83 | # note N is the batch size, pixels_batch has shape [16] 84 | # and its values are [0, 1, 2, ..., 15] * 3500 85 | flat_output_dimensions = out_height * out_width 86 | # a scalar 87 | base = repeat(pixels_batch, flat_output_dimensions) 88 | # base has shape 16 * 3500; see repeat() below
89 | 
90 |         base_y0 = base + y0 * W
91 |         # base already offsets each sample into its own image; add the row offset
92 |         base_y1 = base + y1 * W
93 |         indices_a = base_y0 + x0
94 |         indices_b = base_y1 + x0
95 |         indices_c = base_y0 + x1
96 |         indices_d = base_y1 + x1
97 | 
98 |         # gather the four neighbouring pixel values
99 |         flat_image = tf.reshape(input, shape=(-1, C))
100 |         flat_image = tf.cast(flat_image, dtype=tf.float32)
101 | 
102 |         pixel_values_a = tf.gather(flat_image, indices_a)
103 |         pixel_values_b = tf.gather(flat_image, indices_b)
104 |         pixel_values_c = tf.gather(flat_image, indices_c)
105 |         pixel_values_d = tf.gather(flat_image, indices_d)
106 | 
107 |         x0 = tf.cast(x0, tf.float32)
108 |         x1 = tf.cast(x1, tf.float32)
109 |         y0 = tf.cast(y0, tf.float32)
110 |         y1 = tf.cast(y1, tf.float32)
111 | 
112 |         area_a = tf.expand_dims(((x1 - x) * (y1 - y)), 1)
113 |         area_b = tf.expand_dims(((x1 - x) * (y - y0)), 1)
114 |         area_c = tf.expand_dims(((x - x0) * (y1 - y)), 1)
115 |         area_d = tf.expand_dims(((x - x0) * (y - y0)), 1)
116 | 
117 |         output = tf.add_n([area_a * pixel_values_a,
118 |                            area_b * pixel_values_b,
119 |                            area_c * pixel_values_c,
120 |                            area_d * pixel_values_d])
121 |         # where the interpolation weights vanish (clipped border pixels), fall back to the original image
122 |         mask = area_a + area_b + area_c + area_d
123 |         output = (1 - mask) * flat_image + mask * output
124 | 
125 |         return output
126 | 
127 | def repeat(x, n_repeats):
128 |     # x: int tensor of shape [N] (the per-image base offsets)
129 |     # n_repeats: scalar, H * W
130 |     with tf.variable_scope('_repeat'):
131 |         rep = tf.reshape(tf.ones(shape=tf.stack([n_repeats, ]), dtype=tf.int32), (1, n_repeats))
132 |         # rep has shape (1, n_repeats) and is all ones
133 |         x = tf.matmul(tf.reshape(x, (-1, 1)), rep)
134 |         # the matmul multiplies (N, 1) by (1, n_repeats),
135 |         # so the result has shape (N, n_repeats)
136 |         # and each row i holds n_repeats copies of x[i]
137 |         return tf.reshape(x, [-1])  # shape [N * n_repeats]
138 | 
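A quick sanity-check sketch for bilinear_sample (shapes are illustrative): with an all-zero flow field the warp samples at the identity grid, so the output should match the input up to interpolation error:

import tensorflow as tf

images = tf.random_uniform([4, 64, 64, 3])           # hypothetical batch
flow = tf.zeros([4, 64, 64, 2])                      # zero displacement field
warped = bilinear_sample(images, flow, name='warp')  # -> [4, 64, 64, 3]

with tf.Session() as sess:
    print(sess.run(warped).shape)  # (4, 64, 64, 3)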
-------------------------------------------------------------------------------- /tfLib/gp.py: --------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 | 
5 | import tensorflow as tf
6 | 
7 | def gradient_penalty(f, real, fake, label, mode):
8 |     def _gradient_penalty(f, real, fake=None, label=None):
9 |         def _interpolate(a, b=None):
10 |             if b is not None:  # WGAN-GP: sample uniformly along lines between real and fake
11 |                 shape = [tf.shape(a)[0]] + [1] * (a.shape.ndims - 1)
12 |                 alpha = tf.random_uniform(shape=shape, minval=0., maxval=1.)
13 |                 inter = a + alpha * (b - a)
14 |                 inter.set_shape(a.shape)
15 |             else:
16 |                 inter = a  # no fake given: evaluate the penalty at the real samples
17 |             return inter
18 |         x = _interpolate(real, fake)
19 |         pred = f(x, label)
20 |         grad = tf.gradients(tf.reduce_sum(pred), x)[0]
21 |         norm = tf.norm(tf.reshape(grad, [tf.shape(grad)[0], -1]), axis=1)
22 |         gp = tf.reduce_mean((norm - 1.)**2)
23 | 
24 |         return gp
25 | 
26 |     if mode == 'none':
27 |         gp = tf.constant(0, dtype=real.dtype)
28 |     elif mode == 'Dirac':
29 |         gp = _gradient_penalty(f, real, fake=None, label=label)
30 |     elif mode == 'wgan_gp':
31 |         gp = _gradient_penalty(f, real, fake, label)
32 |     else:
33 |         raise ValueError('unknown gradient penalty mode: %s' % mode)
34 | 
35 |     return gp
36 | 
37 | 
-------------------------------------------------------------------------------- /tfLib/loss.py: --------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 | 
5 | import tensorflow as tf
6 | 
7 | def getfeature_matching_loss(feature1, feature2):
8 |     return tf.reduce_mean(tf.abs(
9 |         tf.reduce_mean(feature1, axis=[1, 2]) - tf.reduce_mean(feature2, axis=[1, 2])))
10 | 
11 | def SSCE(logits, labels):
12 |     loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits))
13 |     return loss
14 | 
15 | def SCE(logits, labels):
16 |     loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(
17 |         labels=labels, logits=logits))
18 |     return loss
19 | 
20 | def cosine(f1, f2):
21 |     f1_norm = tf.nn.l2_normalize(f1, dim=0)
22 |     f2_norm = tf.nn.l2_normalize(f2, dim=0)
23 |     return tf.losses.cosine_distance(f1_norm, f2_norm, dim=0)
24 | 
25 | def MSE(i1, i2):
26 |     return tf.reduce_mean(tf.square(i1 - i2))
27 | 
28 | def L1(i1, i2):
29 |     return tf.reduce_mean(tf.abs(i1 - i2))
30 | 
31 | def TV_loss(i1):
32 |     shape = i1.get_shape().as_list()
33 |     return tf.reduce_mean(tf.image.total_variation(i1)) / (shape[1]*shape[2]*shape[3])
34 | 
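With gp.py and loss.py both in scope, a minimal sketch of combining the pieces into an objective (the toy critic, the random tensors, and the 10.0 penalty weight are illustrative assumptions, not values taken from this repository):

import tensorflow as tf

real = tf.random_uniform([8, 32, 32, 3])
fake = tf.random_uniform([8, 32, 32, 3])

def toy_critic(x, label):
    # stands in for the real discriminator; label is unused here
    return tf.reduce_mean(x, axis=[1, 2, 3])

gp = gradient_penalty(toy_critic, real, fake, label=None, mode='wgan_gp')
rec = L1(real, fake)         # reconstruction term from loss.py
total = rec + 10.0 * gp      # 10.0 is the conventional WGAN-GP weight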
-------------------------------------------------------------------------------- /tfLib/ops.py: --------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from tensorflow.contrib.layers.python.layers import batch_norm
3 | import functools
4 | import numpy as np
5 | import math
6 | 
7 | def log_sum_exp(x, axis=1):
8 |     m = tf.reduce_max(x, axis=axis, keep_dims=True)  # max along `axis` for numerical stability
9 |     return tf.squeeze(m, axis=[axis]) + tf.log(tf.reduce_sum(tf.exp(x - m), axis=axis))
10 | 
11 | def lrelu(x, alpha=0.2, name="LeakyReLU"):
12 |     with tf.variable_scope(name):
13 |         return tf.maximum(x, alpha*x) * tf.sqrt(2.0)  # the sqrt(2) gain keeps activations near unit variance
14 | 
15 | def get_weight(shape, gain=1, use_wscale=True, lrmul=1, weight_var='weight'):
16 | 
17 |     fan_in = np.prod(shape[:-1])  # [kernel, kernel, fmaps_in, fmaps_out] or [in, out]
18 |     he_std = gain / math.sqrt(int(fan_in))  # He init
19 |     # Equalized learning rate and custom learning rate multiplier.
20 |     if use_wscale:
21 |         init_std = 1.0 / lrmul
22 |         runtime_coef = he_std * lrmul
23 |     else:
24 |         init_std = he_std / lrmul
25 |         runtime_coef = lrmul
26 | 
27 |     # Create variable.
28 |     init = tf.initializers.random_normal(0, init_std)
29 |     return tf.get_variable(weight_var, shape=shape, initializer=init) * runtime_coef
30 | 
31 | 
32 | def conv2d(input_, output_dim, k=4, s=2, gain=1, use_wscale=True, lrmul=1,
33 |            weight_var='w', padding='SAME', scope="conv2d", use_bias=True):
34 | 
35 |     assert padding in ['SAME', 'VALID', 'REFLECT']
36 |     with tf.variable_scope(scope):
37 |         w = get_weight([k, k, input_.get_shape()[-1], output_dim], gain=gain, use_wscale=use_wscale, lrmul=lrmul,
38 |                        weight_var=weight_var)
39 | 
40 |         if padding == 'REFLECT':
41 |             input_ = tf.pad(input_, paddings=tf.constant([[0, 0], [1, 1], [1, 1], [0, 0]]), mode='REFLECT')
42 |             conv = tf.nn.conv2d(input_, w, strides=[1, s, s, 1], padding='VALID')
43 |         else:
44 |             conv = tf.nn.conv2d(input_, w, strides=[1, s, s, 1], padding=padding)
45 | 
46 |         if use_bias:
47 |             biases = tf.get_variable('biases', [output_dim], initializer=tf.constant_initializer(0.0))
48 |             conv = tf.reshape(tf.nn.bias_add(conv, biases), tf.shape(conv))
49 | 
50 |         return conv
51 | 
52 | def fully_connect(input_, output_dim, scope=None, use_sp=False, gain=1, use_wscale=True, lrmul=1, weight_var='w',
53 |                   bias_start=0.0, with_w=False):
54 | 
55 |     shape = input_.get_shape().as_list()
56 |     with tf.variable_scope(scope or "Linear"):
57 |         w = get_weight([shape[1], output_dim], gain=gain, use_wscale=use_wscale, lrmul=lrmul, weight_var=weight_var)
58 |         bias = tf.get_variable("bias", [output_dim], tf.float32,
59 |                                initializer=tf.constant_initializer(bias_start))
60 |         # NOTE: use_sp is accepted for API compatibility but has no effect;
61 |         # spectral normalization is not applied in this layer.
62 |         mul = tf.matmul(input_, w)
63 | 
64 |         if with_w:
65 |             return mul + bias, w, bias
66 |         else:
67 |             return mul + bias
68 | 
69 | # def instance_norm(input, scope="instance_norm", affine=True):
70 | #     with tf.variable_scope(scope):
71 | #         depth = input.get_shape()[-1]
72 | # 
73 | #         mean, variance = tf.nn.moments(input, axes=[1, 2], keep_dims=True)
74 | #         epsilon = 1e-5
75 | #         inv = tf.rsqrt(variance + epsilon)
76 | #         normalized = (input - mean) * inv
77 | #         if affine:
78 | #             scale = tf.get_variable("scale", [depth],
79 | #                                     initializer=tf.random_normal_initializer(1.0, 0.02, dtype=tf.float32))
80 | #             offset = tf.get_variable("offset", [depth], initializer=tf.constant_initializer(0.0))
81 | #             return scale * normalized + offset
82 | #         else:
83 | #             return normalized
84 | 
85 | def instance_norm(x, scope='instance_norm'):
86 |     return tf.contrib.layers.instance_norm(x,
87 |                                            epsilon=1e-05,
88 |                                            center=True, scale=True,
89 |                                            scope=scope)
90 | 
91 | def Adaptive_instance_norm(input, beta, gamma, epsilon=1e-5, scope="adaptive_instance_norm"):
92 | 
93 |     ch = beta.get_shape().as_list()[-1]
94 |     with tf.variable_scope(scope):
95 | 
96 |         mean, variance = tf.nn.moments(input, axes=[1, 2], keep_dims=True)
97 |         inv = tf.rsqrt(variance + epsilon)
98 |         normalized = (input - mean) * inv
99 |         beta = tf.reshape(beta, shape=[-1, 1, 1, ch])
100 |         gamma = tf.reshape(gamma, shape=[-1, 1, 1, ch])
101 | 
102 |         return gamma * normalized + beta
103 | 
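Adaptive_instance_norm whitens each channel with per-sample statistics and then re-applies a style-dependent scale and offset. A small sketch with hypothetical constant style parameters (in this repository beta and gamma come from fully connected layers over a style code, as in the resblocks further below):

import tensorflow as tf

content = tf.random_uniform([2, 16, 16, 64])
beta = tf.zeros([2, 64])   # hypothetical style offset
gamma = tf.ones([2, 64])   # hypothetical style scale
out = Adaptive_instance_norm(content, beta=beta, gamma=gamma)
# with gamma = 1 and beta = 0 this reduces to plain instance normalization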
104 | # def modulated_conv2d_layer(input_, style_code, k=3, output_dim=512, padding='SAME', scope='modulated_conv2d'):
105 | #     assert k >= 1 and k % 2 == 1
106 | #     ful = functools.partial(fully_connect)
107 | #     with tf.variable_scope(scope):
108 | #         print(input_)
109 | #         input_ = tf.transpose(input_, [0, 3, 1, 2])
110 | #         w = tf.get_variable('w', [k, k, input_.get_shape()[1], output_dim],
111 | #                             initializer=tf.contrib.layers.variance_scaling_initializer())
112 | #         ww = w[np.newaxis]
113 | #         fmaps = input_.shape[1].value
114 | #         # Modulate
115 | #         style = ful(style_code, output_dim=fmaps)
116 | #         style = style + 1
117 | #         ww = ww * tf.cast(style[:, np.newaxis, np.newaxis, :, np.newaxis], w.dtype)
118 | #         # Demodulate
119 | #         d = tf.rsqrt(tf.reduce_sum(tf.square(ww), axis=[1,2,3]) + 1e-8)
120 | #         ww *= d[:, np.newaxis, np.newaxis, np.newaxis, :]
121 | #         ## Reshape/scale output.
122 | #         input_ = tf.reshape(input_, [1, -1, input_.shape[2], input_.shape[3]])
123 | #         w = tf.reshape(tf.transpose(ww, [1, 2, 3, 0, 4]), [ww.shape[1], ww.shape[2], ww.shape[3], -1])
124 | # 
125 | #         input_ = tf.nn.conv2d(input_, w, strides=[1, 1, 1, 1], data_format='NCHW', padding=padding)
126 | #         print(input_)
127 | #         # Reshape/scale output
128 | #         input_ = tf.reshape(input_, [-1, output_dim, input_.shape[2], input_.shape[3]])
129 | #         input_ = tf.transpose(input_, [0, 2, 3, 1])
130 | # 
131 | #         return input_
132 | 
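The active implementation below follows StyleGAN2-style weight modulation: the style vector rescales the convolution weights per input channel (modulate), and each output feature map's weights are then renormalized to unit L2 norm (demodulate). A plain-numpy sketch of that arithmetic, with illustrative shapes:

import numpy as np

k, cin, cout, n = 3, 8, 16, 4
w = np.random.randn(k, k, cin, cout)
style = np.random.randn(n, cin) + 1.0                    # per-sample channel scales
ww = w[None] * style[:, None, None, :, None]             # modulate -> [N, k, k, Cin, Cout]
d = 1.0 / np.sqrt((ww ** 2).sum(axis=(1, 2, 3)) + 1e-8)  # per-sample, per-output-map norm
ww = ww * d[:, None, None, None, :]                      # demodulate
print(np.linalg.norm(ww[0, :, :, :, 0]))                 # ~1.0: unit-norm filter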
133 | def modulated_conv2d_layer(input_, style_code, k=1, output_dim=512, us=False, gain=1, use_wscale=True, lrmul=1, weight_var='w',
134 |                            padding='SAME', scope='modulated_conv2d'):
135 |     assert k >= 1 and k % 2 == 1
136 |     ful = functools.partial(fully_connect)
137 |     with tf.variable_scope(scope):
138 | 
139 |         # input_ = tf.transpose(input_, [0, 3, 1, 2])
140 |         # w = tf.get_variable('w', [k, k, input_.get_shape()[-1], output_dim],
141 |         #                     initializer=tf.contrib.layers.variance_scaling_initializer())
142 |         w = get_weight([k, k, input_.get_shape()[-1], output_dim], gain=gain, use_wscale=use_wscale, lrmul=lrmul, weight_var=weight_var)
143 |         ww = w[np.newaxis]
144 |         fmaps = input_.shape[-1].value
145 |         # Modulate: scale the weights per input channel with the style
146 |         style = ful(style_code, output_dim=fmaps)
147 |         style = style + 1
148 |         ww = ww * tf.cast(style[:, np.newaxis, np.newaxis, :, np.newaxis], w.dtype)
149 |         # Demodulate: renormalize each output feature map to unit norm
150 |         d = tf.rsqrt(tf.reduce_sum(tf.square(ww), axis=[1, 2, 3]) + 1e-8)
151 |         ww *= d[:, np.newaxis, np.newaxis, np.newaxis, :]
152 | 
153 |         # fold the batch into the channel axis so a single conv applies all per-sample kernels
154 |         input_ = tf.transpose(input_, [1, 2, 0, 3])
155 |         input_ = tf.reshape(input_, [1, input_.shape[0], input_.shape[1], -1])
156 |         w = tf.reshape(tf.transpose(ww, [1, 2, 3, 0, 4]), [ww.shape[1], ww.shape[2], ww.shape[3], -1])
157 |         if us:
158 |             # an earlier attempt used a strided
159 |             # conv2d_transpose here:
160 |             # w = tf.transpose(w, [0, 1, 3, 2])
161 |             # input_ = tf.nn.conv2d_transpose(input_, w, output_shape=[input_.shape[0], input_.shape[1]*2,
162 |             #                                 input_.shape[2]*2, w.shape[-2]], strides=[1, 2, 2, 1])
163 |             # the current code instead upsamples first
164 |             # and then applies a stride-1 convolution:
165 |             input_ = upscale(input_, scale=2)
166 |             input_ = tf.nn.conv2d(input_, w, strides=[1, 1, 1, 1], data_format='NHWC', padding=padding)
167 |         else:
168 |             input_ = tf.nn.conv2d(input_, w, strides=[1, 1, 1, 1], data_format='NHWC', padding=padding)
169 |         # unfold: restore the batch axis and the output channels
170 |         input_ = tf.reshape(input_, [input_.shape[1], input_.shape[2], -1, output_dim])
171 | 
172 |         input_ = tf.transpose(input_, [2, 0, 1, 3])
173 | 
174 |         return input_
175 | 
176 | def Resblock_Mo_Affline_layers(x_init, o_dim, style_code, noise_strength, us=True, scope='resblock'):
177 | 
178 |     _, x_init_h, x_init_w, input_ch = x_init.get_shape().as_list()
179 |     with tf.variable_scope(scope):
180 | 
181 |         def shortcut(x):
182 |             if us:
183 |                 x = upscale(x, scale=2)
184 |             if input_ch != o_dim:
185 |                 x = conv2d(x, output_dim=o_dim, k=1, s=1, scope='conv', use_bias=False)
186 |             return x
187 | 
188 |         with tf.variable_scope('res1'):
189 | 
190 |             x = lrelu(x_init)
191 |             # if us:
192 |             #     x = upscale(x, scale=2)
193 |             x = modulated_conv2d_layer(x, style_code, us=us, output_dim=o_dim, scope='mc1')
194 |             noise = tf.random_normal([tf.shape(x)[0], 1, x.shape[2], x.shape[3]], dtype=x.dtype)
195 |             x += noise * tf.cast(noise_strength, x.dtype)
196 | 
197 |         with tf.variable_scope('res2'):
198 | 
199 |             x = lrelu(x)
200 |             x = modulated_conv2d_layer(x, style_code, us=False, output_dim=o_dim, scope='mc2')
201 |             noise = tf.random_normal([tf.shape(x)[0], 1, x.shape[2], x.shape[3]], dtype=x.dtype)
202 |             x += noise * tf.cast(noise_strength, x.dtype)
203 | 
204 |         if o_dim != input_ch or us:
205 |             x_init = shortcut(x_init)
206 | 
207 |         return (x + x_init) / tf.sqrt(2.0)
208 | 
209 | def Resblock_AdaIn_Affline_layers(x_init, o_dim, style_code, us=True, scope='resblock'):
210 | 
211 |     input_ch = x_init.get_shape().as_list()[-1]
212 |     affline_layers = functools.partial(fully_connect, output_dim=input_ch*2)
213 |     affline_layers2 = functools.partial(fully_connect, output_dim=o_dim*2)
214 | 
215 |     with tf.variable_scope(scope):
216 | 
217 |         def shortcut(x):
218 |             if us:
219 |                 x = upscale(x, scale=2)
220 |             if input_ch != o_dim:
221 |                 x = conv2d(x, output_dim=o_dim, k=1, s=1, scope='conv', padding='VALID', use_bias=False)
222 |             return x
223 | 
224 |         with tf.variable_scope('res1'):
225 |             bg = affline_layers(style_code, scope='fc1')
226 |             beta, gamma = bg[:, 0:input_ch], bg[:, input_ch: input_ch*2]
227 |             x = Adaptive_instance_norm(x_init, beta=beta, gamma=gamma, scope='AdaIn1')
228 |             x = lrelu(x)
229 |             if us:
230 |                 x = upscale(x, scale=2)
231 |             x = conv2d(x, o_dim, k=3, s=1, padding='SAME')
232 | 
233 |         with tf.variable_scope('res2'):
234 |             bg = affline_layers2(style_code, scope='fc2')
235 |             beta, gamma = bg[:, 0:o_dim], bg[:, o_dim: o_dim*2]
236 |             x = Adaptive_instance_norm(x, beta=beta, gamma=gamma, scope='AdaIn2')
237 |             x = lrelu(x)
238 |             x = conv2d(x, o_dim, k=3, s=1, padding='SAME')
239 | 
240 |         if o_dim != input_ch or us:
241 |             x_init = shortcut(x_init)
242 | 
243 |         return (x + x_init) / tf.sqrt(2.0)
244 | 
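The resblocks in this file all return (x + x_init) / tf.sqrt(2.0). A quick numpy check of why the sqrt(2) divisor is there: adding two roughly independent unit-variance branches doubles the variance, and the rescaling restores it:

import numpy as np

a = np.random.randn(10**6)
b = np.random.randn(10**6)
print(np.var(a + b))                 # ~2.0: variance doubles under addition
print(np.var((a + b) / np.sqrt(2)))  # ~1.0: the sqrt(2) divisor restores unit variance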
245 | def Resblock(x_init, o_dim=256, relu_type="lrelu", padding='REFLECT', use_IN=True, ds=True, scope='resblock'):
246 | 
247 |     dim = x_init.get_shape().as_list()[-1]
248 |     conv1 = functools.partial(conv2d, output_dim=dim, padding=padding, k=3, s=1)
249 |     conv2 = functools.partial(conv2d, output_dim=o_dim, padding=padding, k=3, s=1)
250 |     In = functools.partial(instance_norm)
251 | 
252 |     input_ch = x_init.get_shape().as_list()[-1]
253 |     with tf.variable_scope(scope):
254 | 
255 |         def relu(relu_type):
256 |             relu_dict = {
257 |                 "relu": tf.nn.relu,
258 |                 "lrelu": lrelu
259 |             }
260 |             return relu_dict[relu_type]
261 | 
262 |         def shortcut(x):
263 |             if input_ch != o_dim:
264 |                 x = conv2d(x, output_dim=o_dim, k=1, s=1, scope='conv', use_bias=False)
265 |             if ds:
266 |                 x = avgpool2d(x, k=2)
267 |             return x
268 | 
269 |         if use_IN:
270 |             x = conv1(relu(relu_type)(In(x_init, scope='bn1')), scope='c1')
271 |             if ds:
272 |                 x = avgpool2d(x, k=2)
273 |             x = conv2(relu(relu_type)(In(x, scope='bn2')), scope='c2')
274 |         else:
275 |             x = conv1(relu(relu_type)(x_init), scope='c1')
276 |             if ds:
277 |                 x = avgpool2d(x, k=2)
278 |             x = conv2(relu(relu_type)(x), scope='c2')
279 | 
280 |         if input_ch != o_dim or ds:
281 |             x_init = shortcut(x_init)
282 | 
283 |         return (x + x_init) / tf.sqrt(2.0)  # divide by sqrt(2) to keep unit variance
284 | 
285 | def de_conv(input_, output_dim, k_h=4, k_w=4, d_h=2, d_w=2, scope="deconv2d", with_w=False):
286 | 
287 |     with tf.variable_scope(scope):
288 | 
289 |         w = tf.get_variable('w', [k_h, k_w, output_dim[-1], input_.get_shape()[-1]], dtype=tf.float32,
290 |                             initializer=tf.contrib.layers.variance_scaling_initializer())
291 | 
292 |         deconv = tf.nn.conv2d_transpose(input_, w, output_shape=output_dim,
293 |                                         strides=[1, d_h, d_w, 1])
294 | 
295 |         biases = tf.get_variable('biases', [output_dim[-1]], tf.float32, initializer=tf.constant_initializer(0.0))
296 |         deconv = tf.reshape(tf.nn.bias_add(deconv, biases), deconv.get_shape())
297 | 
298 |         if with_w:
299 |             return deconv, w, biases
300 |         else:
301 |             return deconv
302 | 
303 | def avgpool2d(x, k=2):
304 |     return tf.nn.avg_pool(x, ksize=[1, k, k, 1], strides=[1, k, k, 1], padding='SAME')
305 | 
306 | def Adaptive_pool2d(x, output_size=1):
307 |     input_size = get_conv_shape(x)[1]  # spatial size, assuming square NHWC input
308 |     stride = int(input_size / output_size)
309 |     kernel_size = input_size - (output_size - 1) * stride
310 |     return tf.nn.avg_pool(x, ksize=[1, kernel_size, kernel_size, 1], strides=[1, kernel_size, kernel_size, 1], padding='SAME')
311 | 
312 | def upscale(x, scale, method='bilinear'):
313 |     _, h, w, _ = get_conv_shape(x)
314 |     return tf.image.resize(x, size=(h * scale, w * scale), method=method)
315 | 
316 | def get_conv_shape(tensor):
317 |     shape = int_shape(tensor)
318 |     return shape
319 | 
320 | def int_shape(tensor):
321 |     shape = tensor.get_shape().as_list()
322 |     return [num if num is not None else -1 for num in shape]
323 | 
324 | def resize_nearest_neighbor(x, new_size):
325 |     x = tf.image.resize_nearest_neighbor(x, new_size)
326 |     return x
327 | 
328 | def conv_cond_concat(x, y):
329 |     """Concatenate conditioning vector on feature map axis."""
330 |     x_shapes = x.get_shape()
331 |     y_shapes = y.get_shape()
332 |     y_reshaped = tf.reshape(y, [y_shapes[0], 1, 1, y_shapes[-1]])
333 |     return tf.concat([x, y_reshaped*tf.ones([x_shapes[0], x_shapes[1], x_shapes[2], y_shapes[-1]])], 3)
334 | 
335 | def batch_normal(input, scope="scope", reuse=False):
336 |     return batch_norm(input, epsilon=1e-5, decay=0.9, scale=True, scope=scope, reuse=reuse, fused=True, updates_collections=None)
337 | 
338 | def _l2normalize(v, eps=1e-12):
339 |     return v / (tf.reduce_sum(v ** 2) ** 0.5 + eps)
340 | 
341 | def getWeight_Decay(scope='discriminator'):
342 |     return tf.add_n(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES, scope=scope))
343 | 
344 | def getTrainVariable(var_list, scope='discriminator'):
345 |     return [var for var in var_list if scope in var.name]
346 | 
347 | 
-------------------------------------------------------------------------------- /train.py: --------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 | 
5 | import os
6 | from Dataset import CelebA
7 | from SwapAutoEncoderAdaIN import SAE
8 | from config.train_options import TrainOptions
9 | import setproctitle
10 | setproctitle.setproctitle("SAE")
11 | 
12 | opt = TrainOptions().parse()
13 | os.environ['CUDA_VISIBLE_DEVICES'] = str(opt.gpu_id)
14 | 
15 | if __name__ == "__main__":
16 | 
17 |     dataset = CelebA(opt)
18 |     sae = SAE(dataset, opt)
19 |     sae.build_model()
20 |     sae.train()
--------------------------------------------------------------------------------
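A closing note on usage: the repository's scripts directory provides shell wrappers for training and testing (scripts/train_log10_10_1.sh and scripts/test_log10_10_1.sh), which presumably invoke train.py and test.py with the flags defined under config/. A typical invocation would be:

sh scripts/train_log10_10_1.sh

The exact flags live in those scripts and in config/train_options.py; the command above is an assumption about how the pieces fit together.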