├── README.md
├── example_images.png
├── feed_dict.py
├── make_video.py
├── ops.py
├── progan_v15.py
├── progan_v16.py
└── scripts
    ├── downloader.py
    └── image_reshape.py

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ProGAN
2 | 
3 | An implementation of Progressive Growing of GANs, based on the research of Tero Karras et al. at NVIDIA.
4 | 
5 | 
6 | The model was trained on landscape images collected from Reddit.
7 | 
8 | http://research.nvidia.com/sites/default/files/pubs/2017-10_Progressive-Growing-of/karras2018iclr-paper.pdf
9 | 
10 | ![generated images](https://github.com/perplexingpegasus/ProGAN/blob/master/example_images.png?raw=true)
--------------------------------------------------------------------------------
/example_images.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/perplexingpegasus/ProGAN/3fda528bfc9d691ad8748f682c58f11144bd49d4/example_images.png
--------------------------------------------------------------------------------
/feed_dict.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | import numpy as np
4 | from itertools import cycle
5 | 
6 | '''
7 | FeedDict handles several numpy mem_map arrays of image data saved within the directory. The arrays
8 | should be named in the format "n1_n2.npy", where n1 x n1 is the resolution of the image data in the
9 | array and n2 is an index for the file. Data should be of type np.float32, scaled between -1.0 and
10 | 1.0 (progan_v16 instead stores uint8 data and rescales it on the graph). To avoid loading
11 | unnecessary data into memory, only one mem_map is loaded at a time.
12 | '''
13 | 
14 | class FeedDict:
15 | 
16 |     pickle_filename = 'fd_log.pkl'
17 | 
18 |     def __init__(self, logdir, imgdir, z_length, n_examples, shuffle=True, min_size=4, max_size=1024):
19 | 
20 |         self.logdir = logdir
21 |         self.shuffle = shuffle
22 |         self.z_length = z_length
23 | 
24 |         self.sizes = [2 ** i for i in range(
25 |             int(np.log2(min_size)),
26 |             int(np.log2(max_size)) + 1
27 |         )]
28 | 
29 |         files = os.listdir(imgdir)
30 |         self.arrays = dict()
31 | 
32 |         for s in self.sizes:
33 |             path_list = []
34 |             for f in files:
35 | 
36 |                 if f.startswith('{}_'.format(s)):
37 |                     path_list.append(os.path.join(imgdir, f))
38 | 
39 |             if shuffle: np.random.shuffle(path_list)
40 |             self.arrays.update({s: cycle(path_list)})
41 | 
42 |         # Seeded so the fixed preview latents are reproducible across restarts
43 |         self.z_fixed = self.z_batch(n_examples, random_state=z_length)
44 | 
45 |         self.cur_res = None
46 |         self.cur_path = None
47 |         self.cur_array = None
48 |         self.cur_array_len = 0
49 |         self.idx = 0
50 | 
51 |     @property
52 |     def n_sizes(self): return len(self.sizes)
53 | 
54 |     def __change_res(self, res):
55 |         assert res in self.arrays.keys()
56 |         self.cur_res = res
57 |         self.__change_array()
58 | 
59 |     def __change_array(self):
60 |         new_path = next(self.arrays[self.cur_res])
61 |         if new_path != self.cur_path:
62 |             self.cur_path = new_path
63 |             self.cur_array = np.load(new_path)
64 |             self.cur_array_len = self.cur_array.shape[0]
65 |             if self.shuffle: np.random.shuffle(self.cur_array)
66 |             print('Loaded new memmap array: {}'.format(new_path))
67 |         self.idx = 0
68 | 
69 |     def z_batch(self, batch_size, random_state=None):
70 |         if random_state is not None:
71 |             np.random.seed(random_state)
72 |         return np.random.normal(0.0, 1.0, size=[batch_size, self.z_length])
73 | 
74 |     def x_batch(self, batch_size, res):
75 |         if res != self.cur_res:
76 |             self.__change_res(res)
77 | 
78 |         remaining = self.cur_array_len - self.idx
79 |         start = self.idx
80 | 
81 |         if remaining >= batch_size:
82 |             stop = start + batch_size
83 |             batch = self.cur_array[start:stop]
84 | 
85 |         else:
86 |             # Batch straddles two arrays: take the tail of this one, load the
87 |             # next, and fill the rest from its head
88 |             stop = batch_size - remaining
89 |             batch = self.cur_array[start:]
90 |             self.__change_array()
91 |             batch = np.concatenate((batch, self.cur_array[:stop]))
92 | 
93 |         self.idx = stop
94 |         return batch
95 | 
96 |     @classmethod
97 |     def load(cls, logdir, **kwargs):
98 |         path = os.path.join(logdir, cls.pickle_filename)
99 |         if os.path.exists(path):
100 |             with open(path, 'rb') as f:
101 |                 fd = pickle.load(f)
102 |             if type(fd) == cls:
103 |                 print('Restored feed_dict -------\n')
104 |                 return fd
105 |         return cls(logdir, **kwargs)
106 | 
107 |     def save(self):
108 |         path = os.path.join(self.logdir, self.pickle_filename)
109 |         with open(path, 'wb') as f:
110 |             pickle.dump(self, f, pickle.HIGHEST_PROTOCOL)
--------------------------------------------------------------------------------
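A minimal sketch of writing an array file in the layout FeedDict scans for. The "{n1}_{n2}.npy" naming comes from the docstring above, and the uint8 NCHW layout matches what scripts/image_reshape.py produces; the directory name here is only an example:

import numpy as np

# 64 RGB images at 128x128 resolution, channels-first, named "{resolution}_{index}.npy"
batch = np.random.randint(0, 256, size=(64, 3, 128, 128), dtype=np.uint8)
np.save('img_arrays/128_0.npy', batch)

FeedDict.x_batch then cycles through every file whose name starts with "128_" whenever 128x128 batches are requested, loading one array at a time and wrapping across file boundaries when a batch straddles two arrays.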
/make_video.py:
--------------------------------------------------------------------------------
1 | from progan_v15 import ProGAN
2 | 
3 | import librosa
4 | import numpy as np
5 | from moviepy.video.VideoClip import VideoClip
6 | from moviepy.editor import AudioFileClip
7 | from sklearn.preprocessing import StandardScaler
8 | 
9 | 
10 | def get_z_from_audio(audio, z_length, n_bins=60, hop_length=512, random_state=50):
11 |     np.random.seed(random_state)
12 |     if type(audio) == str:
13 |         audio, sr = librosa.load(audio)
14 | 
15 |     # One standardized CQT magnitude vector per audio frame
16 |     y = librosa.core.cqt(audio, n_bins=n_bins, hop_length=hop_length)
17 |     mag, phase = librosa.core.magphase(y)
18 |     mag = mag.T
19 |     mag = StandardScaler().fit_transform(mag)
20 | 
21 |     # Fill the remaining latent dimensions with a constant random vector
22 |     s0, s1 = mag.shape
23 |     static = np.random.normal(size=[z_length - s1])
24 |     static = np.tile(static, (s0, 1))
25 | 
26 |     # Shuffle which latent dimensions carry the audio features
27 |     z = np.concatenate((mag, static), 1)
28 |     z = z.T
29 |     np.random.shuffle(z)
30 |     z = z.T
31 |     return z
32 | 
33 | def make_video(audio, filename, progan, n_bins=60, random_state=0, imgs_per_batch=20):
34 |     y, sr = librosa.load(audio)
35 |     song_length = len(y) / sr
36 |     z_audio = get_z_from_audio(y, z_length=progan.z_length, n_bins=n_bins, random_state=random_state)
37 |     fps = z_audio.shape[0] / song_length
38 |     res = progan.get_cur_res()
39 |     shape = (res, res * 16 // 9, 3)
40 | 
41 |     imgs = np.zeros(shape=[imgs_per_batch, *shape], dtype=np.float32)
42 | 
43 |     def make_frame(t):
44 |         nonlocal imgs  # was `global imgs`, which bypassed the array allocated above
45 |         cur_frame_idx = int(t * fps)
46 | 
47 |         if cur_frame_idx >= len(z_audio):
48 |             return np.zeros(shape=shape, dtype=np.uint8)
49 | 
50 |         # Generate a new batch of frames, cropped and mirrored to a 16:9 aspect ratio
51 |         if cur_frame_idx % imgs_per_batch == 0:
52 |             imgs = progan.generate(z_audio[cur_frame_idx:cur_frame_idx + imgs_per_batch])
53 |             imgs = imgs[:, :, :res * 8 // 9, :]
54 |             imgs_rev = np.flip(imgs, 2)
55 |             imgs = np.concatenate((imgs, imgs_rev), 2)
56 | 
57 |         return imgs[cur_frame_idx % imgs_per_batch]
58 | 
59 |     video_clip = VideoClip(make_frame=make_frame, duration=song_length)
60 |     audio_clip = AudioFileClip(audio)
61 |     video_clip = video_clip.set_audio(audio_clip)
62 |     video_clip.write_videofile(filename, fps=fps)
63 | 
64 | if __name__ == '__main__':
65 |     progan = ProGAN(
66 |         logdir='logdir_v2',
67 |         imgdir='img_arrays',
68 |     )
69 |     make_video('videos\\eco_zones.mp3', 'eco_zones.mp4', progan, random_state=768)
--------------------------------------------------------------------------------
/ops.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | 
3 | 
4 | weight_init = tf.random_normal_initializer()
5 | bias_init = tf.constant_initializer(0)
6 | 
7 | 
8 | def conv(input, 
out_channels, filter_size=3, k=1, padding='SAME', mode=None, output_shape=None): 9 | 10 | in_shape = tf.shape(input) 11 | input_channels = int(input.get_shape()[1]) 12 | 13 | if mode == 'upscale' or mode == 'transpose': 14 | filter_shape = [filter_size, filter_size, out_channels, input_channels] 15 | else: 16 | filter_shape = [filter_size, filter_size, input_channels, out_channels] 17 | 18 | filter = tf.get_variable('filter', filter_shape, initializer=weight_init) 19 | fan_in = float(filter_size ** 2 * input_channels) 20 | filter = filter * tf.sqrt(2.0 / fan_in) 21 | 22 | b = tf.get_variable('bias', [1, out_channels, 1, 1], initializer=bias_init) 23 | 24 | if mode == 'upscale': 25 | filter = tf.pad(filter, [[1, 1], [1, 1], [0, 0], [0, 0]], mode='CONSTANT') 26 | filter = tf.add_n([filter[1:, 1:], filter[:-1, 1:], filter[1:, :-1], filter[:-1, :-1]]) 27 | output_shape = [in_shape[0], out_channels, in_shape[2] * 2, in_shape[3] * 2] 28 | output = tf.nn.conv2d_transpose(input, filter, output_shape, [1, 1, 2, 2], 29 | padding=padding, data_format='NCHW') 30 | 31 | elif mode == 'downscale': 32 | filter = tf.pad(filter, [[1, 1], [1, 1], [0, 0], [0, 0]], mode='CONSTANT') 33 | filter = tf.add_n([filter[1:, 1:], filter[:-1, 1:], filter[1:, :-1], filter[:-1, :-1]]) 34 | filter *= 0.25 35 | output = tf.nn.conv2d(input, filter, [1, 1, 2, 2], padding=padding, data_format='NCHW') 36 | 37 | elif mode == 'transpose': 38 | output = tf.nn.conv2d_transpose(input, filter, output_shape, [1, 1, k, k], 39 | padding=padding, data_format='NCHW') 40 | 41 | else: 42 | output = tf.nn.conv2d(input, filter, [1, 1, k, k], padding=padding, data_format='NCHW') 43 | 44 | output += b 45 | 46 | if out_channels == 1: 47 | output = tf.squeeze(output, 3) 48 | 49 | return output 50 | 51 | 52 | def dense(input, output_size): 53 | fan_in = int(input.get_shape()[1]) 54 | W = tf.get_variable('W', [fan_in, output_size], initializer=weight_init) 55 | W = W * tf.sqrt(2.0 / float(fan_in)) 56 | b = tf.get_variable('b', [1, output_size, 1, 1], initializer=bias_init) 57 | return tf.matmul(input, W) + b 58 | 59 | 60 | def leaky_relu(input, alpha=0.2): 61 | return tf.nn.leaky_relu(input, alpha=alpha) 62 | 63 | 64 | def pixelwise_norm(input): 65 | pixel_var = tf.reduce_mean(tf.square(input), 1, keepdims=True) 66 | return input / tf.sqrt(pixel_var + 1e-8) 67 | 68 | 69 | def g_conv_layer(input, out_channels, **kwargs): 70 | return pixelwise_norm(leaky_relu(conv(input, out_channels, **kwargs))) 71 | 72 | 73 | def d_conv_layer(input, out_channels, **kwargs): 74 | return leaky_relu(conv(input, out_channels, **kwargs)) 75 | 76 | 77 | def minibatch_stddev(input): 78 | shape = tf.shape(input) 79 | group_size = tf.minimum(4, shape[0]) 80 | x = tf.reshape(input, [group_size, -1, shape[1], shape[2], shape[3]]) 81 | 82 | mu = tf.reduce_mean(x, axis=0, keepdims=True) 83 | sigma = tf.sqrt(tf.reduce_mean(tf.square(x - mu), axis=0) + 1e-8) 84 | 85 | sigma_avg = tf.reduce_mean(sigma, axis=[1, 2, 3], keepdims=True) 86 | sigma_avg = tf.tile(sigma_avg, [group_size, 1, shape[2], shape[3]]) 87 | return tf.concat((input, sigma_avg), axis=1) 88 | 89 | 90 | def upscale(input): 91 | shape = tf.shape(input) 92 | channels = input.get_shape()[1] 93 | output = tf.reshape(input, [-1, channels, shape[2], 1, shape[3], 1]) 94 | output = tf.tile(output, [1, 1, 1, 2, 1, 2]) 95 | return tf.reshape(output, [-1, channels, shape[2] * 2, shape[3] * 2]) 96 | 97 | 98 | def downscale(input): 99 | return tf.nn.avg_pool(input, ksize=[1, 1, 2, 2], strides=[1, 1, 2, 2], 100 | padding='SAME', 
data_format='NCHW') 101 | 102 | 103 | def resize_images(input, dims=None): 104 | if dims is None: 105 | dims = tf.shape(input)[2] * 2, tf.shape(input)[3] * 2 106 | return tf.image.resize_nearest_neighbor(input, dims) 107 | 108 | 109 | def scale_uint8(input): 110 | input = tf.to_float(input) 111 | return (input / 127.5) - 1 112 | 113 | 114 | def tensor_to_imgs(input, switch_dims=True): 115 | if switch_dims: input = tf.transpose(input, (0, 2, 3, 1)) 116 | imgs = tf.minimum(tf.maximum(input, -tf.ones_like(input)), tf.ones_like(input)) 117 | imgs = (imgs + 1) * 127.5 118 | return tf.cast(imgs, tf.uint8) -------------------------------------------------------------------------------- /progan_v15.py: -------------------------------------------------------------------------------- 1 | import os 2 | import datetime as dt 3 | 4 | # Operations used in building the network. Many are not used in the current model 5 | from ops import * 6 | # FeedDict object used to continuously provide new training data 7 | from feed_dict import FeedDict 8 | 9 | 10 | # TODO: add argparser and flags 11 | # TODO: refactor training function 12 | # TODO: train next version of model using reset_optimizer=True 13 | 14 | 15 | class ProGAN: 16 | def __init__(self, 17 | logdir, # directory of stored models 18 | imgdir, # directory of images for FeedDict 19 | learning_rate=0.001, # Adam optimizer learning rate 20 | beta1=0, # Adam optimizer beta1 21 | beta2=0.99, # Adam optimizer beta2 22 | w_lambda=10.0, # WGAN-GP/LP lambda 23 | w_gamma=1.0, # WGAN-GP/LP gamma 24 | epsilon=0.001, # WGAN-GP/LP lambda 25 | z_length=512, # latent variable size 26 | n_imgs=800000, # number of images to show in each growth step 27 | batch_repeats=1, # number of times to repeat minibatch 28 | n_examples=24, # number of example images to generate 29 | lipschitz_penalty=True, # if True, use WGAN-LP instead of WGAN-GP 30 | big_image=True, # Generate a single large preview image, only works if n_examples = 24 31 | scaling_factor=None, # factor to scale down number of trainable parameters 32 | reset_optimizer=False, # reset optimizer variables with each new layer 33 | ): 34 | 35 | # Scale down the number of factors if scaling_factor is provided 36 | self.channels = [512, 512, 512, 512, 256, 128, 64, 32, 16, 8] 37 | if scaling_factor: 38 | assert scaling_factor > 1 39 | self.channels = [max(4, c // scaling_factor) for c in self.channels] 40 | 41 | self.batch_size = [16, 16, 16, 16, 16, 16, 8, 4, 3] 42 | self.z_length = z_length 43 | self.n_examples = n_examples 44 | self.batch_repeats = batch_repeats if batch_repeats else 1 45 | self.n_imgs = n_imgs 46 | self.logdir = logdir 47 | self.big_image = big_image 48 | self.w_lambda = w_lambda 49 | self.w_gamma = w_gamma 50 | self.epsilon = epsilon 51 | self.reset_optimizer=reset_optimizer 52 | self.lipschitz_penalty = lipschitz_penalty 53 | self.start = True 54 | 55 | # Generate fized latent variables for image previews 56 | np.random.seed(0) 57 | self.z_fixed = np.random.normal(size=[self.n_examples, self.z_length]) 58 | 59 | # Initialize placeholders 60 | self.x_placeholder = tf.placeholder(tf.float32, [None, None, None, 3]) 61 | self.z_placeholder = tf.placeholder(tf.float32, [None, self.z_length]) 62 | 63 | # Global step 64 | with tf.variable_scope('global_step'): 65 | self.global_step = tf.Variable(0, name='global_step', trainable=False) 66 | self.global_step_op = tf.assign(self.global_step, tf.add(self.global_step, 1)) 67 | 68 | # Non-trainable variables for counting to next layer and incrementing value of 
alpha 69 | with tf.variable_scope('image_count'): 70 | self.total_imgs = tf.Variable(0.0, name='image_step', trainable=False) 71 | self.img_count_placeholder = tf.placeholder(tf.float32) 72 | self.img_step_op = tf.assign(self.total_imgs, 73 | tf.add(self.total_imgs, self.img_count_placeholder)) 74 | 75 | self.img_step = tf.mod(tf.add(self.total_imgs, self.n_imgs), self.n_imgs * 2) 76 | self.alpha = tf.minimum(1.0, tf.div(self.img_step, self.n_imgs)) 77 | self.layer = tf.floor_div(tf.add(self.total_imgs, self.n_imgs), self.n_imgs * 2) 78 | 79 | # Initialize optimizer as member variable if not rest_optimizer, otherwise generate new 80 | # optimizer for each layer 81 | if self.reset_optimizer: 82 | self.lr = learning_rate 83 | self.beta1 = beta1 84 | self.beta2 = beta2 85 | else: 86 | self.g_optimizer = tf.train.AdamOptimizer(learning_rate, beta1, beta2) 87 | self.d_optimizer = tf.train.AdamOptimizer(learning_rate, beta1, beta2) 88 | 89 | # Initialize FeedDict 90 | self.feed = FeedDict.load(imgdir, logdir) 91 | self.n_layers = int(np.log2(1024)) - 1 92 | self.networks = [self._create_network(i + 1) for i in range(self.n_layers)] 93 | 94 | # Initialize Session, FileWriter and Saver 95 | self.sess = tf.Session() 96 | self.sess.run(tf.global_variables_initializer()) 97 | self.writer = tf.summary.FileWriter(self.logdir, graph=self.sess.graph) 98 | self.saver = tf.train.Saver() 99 | 100 | # Look in logdir to see if a saved model already exists. If so, load it 101 | try: 102 | self.saver.restore(self.sess, tf.train.latest_checkpoint(self.logdir)) 103 | print('Restored ----------------\n') 104 | except Exception: 105 | pass 106 | 107 | # Function for fading input of current layer into previous layer based on current value of alpha 108 | def _reparameterize(self, x0, x1): 109 | return tf.add( 110 | tf.scalar_mul(tf.subtract(1.0, self.alpha), x0), 111 | tf.scalar_mul(self.alpha, x1) 112 | ) 113 | 114 | # Function for creating network layout at each layer 115 | def _create_network(self, layers): 116 | 117 | # Build the generator for this layer 118 | def generator(z): 119 | with tf.variable_scope('Generator'): 120 | with tf.variable_scope('latent_vector'): 121 | z = tf.expand_dims(z, 1) 122 | g1 = tf.expand_dims(z, 2) 123 | for i in range(layers): 124 | with tf.variable_scope('layer_{}'.format(i)): 125 | if i > 0: 126 | g1 = resize(g1) 127 | if i == layers - 1 and layers > 1: 128 | g0 = g1 129 | with tf.variable_scope('1'): 130 | if i == 0: 131 | g1 = pixelwise_norm(leaky_relu(conv2d_transpose( 132 | g1, [tf.shape(g1)[0], 4, 4, self.channels[0]]))) 133 | else: 134 | g1 = pixelwise_norm(leaky_relu(conv2d(g1, self.channels[i]))) 135 | with tf.variable_scope('2'): 136 | g1 = pixelwise_norm(leaky_relu(conv2d(g1, self.channels[i]))) 137 | with tf.variable_scope('rgb_layer_{}'.format(layers - 1)): 138 | g1 = conv2d(g1, 3, 1, weight_norm=False) 139 | if layers > 1: 140 | with tf.variable_scope('rgb_layer_{}'.format(layers - 2)): 141 | g0 = conv2d(g0, 3, 1, weight_norm=False) 142 | g = self._reparameterize(g0, g1) 143 | else: 144 | g = g1 145 | return g 146 | 147 | # Build the discriminator for this layer 148 | def discriminator(x): 149 | with tf.variable_scope('Discriminator'): 150 | if layers > 1: 151 | with tf.variable_scope('rgb_layer_{}'.format(layers - 2)): 152 | d0 = avg_pool(x) 153 | d0 = leaky_relu(conv2d(d0, self.channels[layers - 1], 1)) 154 | with tf.variable_scope('rgb_layer_{}'.format(layers - 1)): 155 | d1 = leaky_relu(conv2d(x, self.channels[layers], 1)) 156 | for i in reversed(range(layers)): 
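        # The loop below mirrors the generator in reverse: each layer_i block applies
        # two 3x3 conv + leaky ReLU stages with average pooling between resolutions;
        # at i == 0 a minibatch-stddev feature map is appended and the final conv uses
        # a 4x4 'VALID' window to collapse the spatial dimensions before the dense
        # output. As in the generator, _reparameterize blends the newest layer with
        # the previous one, computing (1 - alpha) * x0 + alpha * x1 as alpha ramps
        # from 0 to 1 over the first n_imgs images of each growth phase.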
157 | with tf.variable_scope('layer_{}'.format(i)): 158 | if i == 0: 159 | d1 = minibatch_stddev(d1) 160 | with tf.variable_scope('1'): 161 | d1 = leaky_relu(conv2d(d1, self.channels[i])) 162 | with tf.variable_scope('2'): 163 | if i == 0: 164 | d1 = leaky_relu(conv2d(d1, self.channels[i], 4, padding='VALID')) 165 | else: 166 | d1 = leaky_relu(conv2d(d1, self.channels[i])) 167 | if i != 0: 168 | d1 = avg_pool(d1) 169 | if i == layers - 1 and layers > 1: 170 | d1 = self._reparameterize(d0, d1) 171 | with tf.variable_scope('dense'): 172 | d = tf.reshape(d1, [-1, self.channels[0]]) 173 | d = dense_layer(d, 1) 174 | return d 175 | 176 | # image dimensions 177 | dim = 2 ** (layers + 1) 178 | 179 | # Build the current network 180 | with tf.variable_scope('Network', reuse=tf.AUTO_REUSE): 181 | Gz = generator(self.z_placeholder) 182 | Dz = discriminator(Gz) 183 | 184 | # Mix different resolutions of input images according to value of alpha 185 | with tf.variable_scope('reshape'): 186 | if layers > 1: 187 | x0 = resize(self.x_placeholder, (dim // 2, dim // 2)) 188 | x0 = resize(x0, (dim, dim)) 189 | x1 = resize(self.x_placeholder, (dim, dim)) 190 | x = self._reparameterize(x0, x1) 191 | else: 192 | x = resize(self.x_placeholder, (dim, dim)) 193 | Dx = discriminator(x) 194 | 195 | # Fake and real image mixing for WGAN-GP loss function 196 | interp = tf.random_uniform(shape=[tf.shape(Dz)[0], 1, 1, 1], minval=0., maxval=1.) 197 | x_hat = interp * x + (1 - interp) * Gz 198 | Dx_hat = discriminator(x_hat) 199 | 200 | # Loss function and scalar summaries 201 | with tf.variable_scope('Loss_Function'): 202 | 203 | # Wasserstein Distance 204 | wd = Dz - Dx 205 | 206 | # Gradient/Lipschitz Penalty 207 | grads = tf.gradients(Dx_hat, [x_hat])[0] 208 | slopes = tf.sqrt(tf.reduce_sum(tf.square(grads), [1, 2, 3])) 209 | if self.lipschitz_penalty: 210 | gp = tf.square(tf.maximum((slopes - self.w_gamma) / self.w_gamma, 0)) 211 | else: 212 | gp = tf.square((slopes - self.w_gamma) / self.w_gamma) 213 | gp_scaled = self.w_lambda * gp 214 | 215 | # Epsilon penalty keeps discriminator output for drifting too far away from zero 216 | epsilon_cost = self.epsilon * tf.square(Dx) 217 | 218 | # Cost and summary scalars 219 | g_cost = tf.reduce_mean(-Dz) 220 | d_cost = tf.reduce_mean(wd + gp_scaled + epsilon_cost) 221 | wd = tf.abs(tf.reduce_mean(wd)) 222 | gp = tf.reduce_mean(gp) 223 | 224 | # Summaries 225 | wd_sum = tf.summary.scalar('Wasserstein_distance_{}x{}'.format(dim, dim), wd) 226 | gp_sum = tf.summary.scalar('gradient_penalty_{}x{}'.format(dim, dim), gp) 227 | 228 | # Collecting variables to be trained by optimizers 229 | g_vars, d_vars = [], [] 230 | var_scopes = ['layer_{}'.format(i) for i in range(layers)] 231 | var_scopes.extend(['dense', 'rgb_layer_{}'.format(layers - 1), 'rgb_layer_{}'.format(layers - 2)]) 232 | for scope in var_scopes: 233 | g_vars.extend(tf.get_collection( 234 | tf.GraphKeys.GLOBAL_VARIABLES, 235 | scope='Network/Generator/{}'.format(scope))) 236 | d_vars.extend(tf.get_collection( 237 | tf.GraphKeys.GLOBAL_VARIABLES, 238 | scope='Network/Discriminator/{}'.format(scope))) 239 | 240 | # Generate optimizer operations 241 | # if self.reset_optimizer is True then initialize a new optimizer for each layer 242 | with tf.variable_scope('Optimize'): 243 | if self.reset_optimizer: 244 | g_train = tf.train.AdamOptimizer( 245 | self.lr, self.beta1, self.beta2, name='G_optimizer_{}'.format(layers - 1)).minimize( 246 | g_cost, var_list=g_vars) 247 | d_train = tf.train.AdamOptimizer( 248 | self.lr, 
self.beta1, self.beta2, name='D_optimizer_{}'.format(layers - 1)).minimize( 249 | d_cost, var_list=d_vars) 250 | else: 251 | g_train = self.g_optimizer.minimize(g_cost, var_list=g_vars) 252 | d_train = self.d_optimizer.minimize(d_cost, var_list=d_vars) 253 | 254 | # Print variable names to before running model 255 | print([var.name for var in g_vars]) 256 | print([var.name for var in d_vars]) 257 | 258 | # Generate preview images 259 | with tf.variable_scope('image_preview'): 260 | fake_imgs = tf.minimum(tf.maximum(Gz, -tf.ones_like(Gz)), tf.ones_like(Gz)) 261 | real_imgs = x[:min(self.batch_size[layers - 1], 4), :, :, :] 262 | 263 | # Upsize images to normal visibility 264 | if dim < 256: 265 | fake_imgs = resize(fake_imgs, (256, 256)) 266 | real_imgs = resize(real_imgs, (256, 256)) 267 | 268 | # Concatenate images into one large image for preview, only used if 24 preview images are requested 269 | if self.big_image and self.n_examples == 24: 270 | fake_img_list = tf.unstack(fake_imgs, num=24) 271 | fake_img_list = [tf.concat(fake_img_list[6 * i:6 * (i + 1)], 1) for i in range(4)] 272 | fake_imgs = tf.concat(fake_img_list, 0) 273 | fake_imgs = tf.expand_dims(fake_imgs, 0) 274 | 275 | real_img_list = tf.unstack(real_imgs, num=min(self.batch_size[layers - 1], 4)) 276 | real_imgs = tf.concat(real_img_list, 1) 277 | real_imgs = tf.expand_dims(real_imgs, 0) 278 | 279 | # images summaries 280 | fake_img_sum = tf.summary.image('fake{}x{}'.format(dim, dim), fake_imgs, self.n_examples) 281 | real_img_sum = tf.summary.image('real{}x{}'.format(dim, dim), real_imgs, 4) 282 | 283 | return (dim, wd, gp, wd_sum, gp_sum, g_train, d_train, 284 | fake_img_sum, real_img_sum, Gz, discriminator) 285 | 286 | # Summary adding function 287 | def _add_summary(self, string, gs): 288 | self.writer.add_summary(string, gs) 289 | 290 | # Latent variable 'z' generator 291 | def _z(self, batch_size): 292 | return np.random.normal(0.0, 1.0, [batch_size, self.z_length]) 293 | 294 | # Main training function 295 | def train(self): 296 | prev_layer = None 297 | start_time = dt.datetime.now() 298 | total_imgs = self.sess.run(self.total_imgs) 299 | 300 | while total_imgs < (self.n_layers - 0.5) * self.n_imgs * 2: 301 | 302 | # Get current layer, global step, alpha and total number of images used so far 303 | layer, gs, img_step, alpha, total_imgs = self.sess.run([ 304 | self.layer, self.global_step, self.img_step, self.alpha, self.total_imgs]) 305 | layer = int(layer) 306 | 307 | # Global step interval to save model and generate image previews 308 | save_interval = max(1000, 10000 // 2 ** layer) 309 | 310 | # Get network operations and loss functions for current layer 311 | (dim, wd, gp, wd_sum, gp_sum, g_train, d_train, 312 | fake_img_sum, real_img_sum, Gz, discriminator) = self.networks[layer] 313 | 314 | # Get training data and latent variables to store in feed_dict 315 | feed_dict = {self.x_placeholder: self.feed.next_batch(self.batch_size[layer], dim), 316 | self.z_placeholder: self._z(self.batch_size[layer])} 317 | 318 | # Reset start times if a new layer has begun training 319 | if layer != prev_layer: 320 | start_time = dt.datetime.now() 321 | 322 | # Here's where we actually train the model 323 | for _ in range(self.batch_repeats): 324 | self.sess.run(g_train, feed_dict) 325 | self.sess.run(d_train, feed_dict) 326 | 327 | # Get loss values and summaries 328 | wd_, gp_, wd_sum_str, gp_sum_str = self.sess.run([wd, gp, wd_sum, gp_sum], feed_dict) 329 | 330 | # Print current status, loss functions, etc. 
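            # Worked example of the bookkeeping above: with n_imgs = 800000 and
            # total_imgs = 1000000, img_step = (1000000 + 800000) % 1600000 = 200000,
            # layer = (1000000 + 800000) // 1600000 = 1 (the 8x8 network), and
            # alpha = min(1.0, 200000 / 800000) = 0.25, i.e. a quarter of the way
            # through fading in the 8x8 layer.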
331 |             percent_done = np.round(img_step * 50 / self.n_imgs, 4)
332 |             imgs_done = int(img_step)
333 |             cur_layer_imgs = self.n_imgs * 2
334 |             if dim == 4:
335 |                 percent_done = np.round((percent_done - 50) * 2, 4)
336 |                 imgs_done -= self.n_imgs
337 |                 cur_layer_imgs //= 2
338 |             print('dimensions: {}x{} ---- {}% ---- images: {}/{} ---- alpha: {} ---- global step: {}'
339 |                   '\nWasserstein distance: {}\ngradient penalty: {}\n'.format(
340 |                 dim, dim, percent_done, imgs_done, cur_layer_imgs, alpha, gs, wd_, gp_))
341 | 
342 |             # Log scalar data every 20 global steps
343 |             if gs % 20 == 0:
344 |                 self._add_summary(wd_sum_str, gs)
345 |                 self._add_summary(gp_sum_str, gs)
346 | 
347 |             # Operations to run every save interval
348 |             if gs % save_interval == 0:
349 | 
350 |                 # Do not save the model or generate images immediately after loading/preloading
351 |                 if self.start:
352 |                     self.start = False
353 | 
354 |                 # Save the model and generate image previews
355 |                 else:
356 |                     print('saving and making images...\n')
357 |                     self.feed.save()
358 |                     self.saver.save(
359 |                         self.sess, os.path.join(self.logdir, "model.ckpt"),
360 |                         global_step=self.global_step)
361 |                     real_img_sum_str = self.sess.run(real_img_sum, feed_dict)
362 |                     img_preview_feed_dict = {
363 |                         self.x_placeholder: feed_dict[self.x_placeholder][:4],
364 |                         self.z_placeholder: self.z_fixed}
365 |                     fake_img_sum_str = self.sess.run(fake_img_sum, img_preview_feed_dict)
366 |                     self._add_summary(fake_img_sum_str, gs)
367 |                     self._add_summary(real_img_sum_str, gs)
368 | 
369 |             # Increment image count and global step variables
370 |             img_count = self.batch_repeats * self.batch_size[layer]
371 |             self.sess.run(self.global_step_op)
372 |             self.sess.run(self.img_step_op, {self.img_count_placeholder: img_count})
373 | 
374 |             # Calculate and print estimated time remaining
375 |             prev_layer = layer
376 |             avg_time = (dt.datetime.now() - start_time) / (imgs_done + self.batch_size[layer])
377 |             imgs_remaining = cur_layer_imgs - imgs_done
378 |             time_remaining = avg_time * imgs_remaining
379 |             print('est. time remaining on current layer: {}'.format(time_remaining))
380 | 
381 |     def get_cur_res(self):
382 |         cur_layer = int(self.sess.run(self.layer))
383 |         return 2 ** (2 + cur_layer)
384 | 
385 |     # Function for generating images from a 1D or 2D array of latent vectors
386 |     def generate(self, z):
387 |         if len(z.shape) == 1:
388 |             z = np.expand_dims(z, 0)
389 | 
390 |         cur_layer = int(self.sess.run(self.layer))
391 |         G = self.networks[cur_layer][9]
392 |         imgs = self.sess.run(G, {self.z_placeholder: z})
393 | 
394 |         imgs = np.minimum(imgs, 1.0)
395 |         imgs = np.maximum(imgs, -1.0)
396 |         imgs = (imgs + 1) * 255 / 2
397 |         imgs = np.uint8(imgs)
398 | 
399 |         if imgs.shape[0] == 1:
400 |             imgs = np.squeeze(imgs, 0)
401 |         return imgs
402 | 
403 | 
404 |     def transform(self, input_img, n_iter=100000):
405 |         with tf.variable_scope('transform'):
406 |             global_step = tf.Variable(0, name='transform_global_step', trainable=False)
407 |             transform_img = tf.Variable(input_img, name='transform_img', dtype=tf.float32)
408 | 
409 |         cur_layer = int(self.sess.run(self.layer))
410 |         (dim, wd, gp, wd_sum, gp_sum, g_train, d_train,
411 |          fake_img_sum, real_img_sum, Gz, discriminator) = self.networks[cur_layer]
412 | 
413 |         with tf.variable_scope('Network', reuse=tf.AUTO_REUSE):
414 |             with tf.variable_scope('resize'):
415 |                 jitter = tf.random_uniform([2], -10, 10, tf.int32)
416 |                 img = tf.manip.roll(transform_img, jitter, [1, 2])
417 |                 img = resize(img, (dim, dim))
418 |             Dt = discriminator(img)
419 | 
420 |         t_cost = tf.reduce_mean(-Dt)
421 |         tc_sum = tf.summary.scalar('transform_cost_{}x{}'.format(dim, dim), t_cost)
422 |         t_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='transform/transform_img')
423 |         t_train = tf.train.AdamOptimizer(0.0001).minimize(
424 |             t_cost, var_list=t_vars, global_step=global_step)
425 |         transform_img_sum = tf.summary.image('transform', transform_img)
426 | 
427 |         self.sess.run(tf.global_variables_initializer())
428 | 
429 |         for i in range(n_iter):
430 |             gs, t_cost_, tc_sum_str, _ = self.sess.run([global_step, t_cost, tc_sum, t_train])
431 |             print('Global step: {}, cost: {}\n\n'.format(gs, t_cost_))
432 |             if i % 20 == 0:
433 |                 self._add_summary(tc_sum_str, gs)
434 |             if i % 1000 == 0:
435 |                 img_sum_str = self.sess.run(transform_img_sum)
436 |                 self._add_summary(img_sum_str, gs)
437 | 
438 | 
439 | if __name__ == '__main__':
440 | 
441 |     progan = ProGAN(
442 |         logdir='logdir_v2',
443 |         imgdir='img_arrays',
444 |     )
445 |     # progan = ProGAN(
446 |     #     logdir='logdir_v3',
447 |     #     imgdir='img_arrays_botanical',
448 |     #     reset_optimizer=True
449 |     # )
450 |     progan.train()
--------------------------------------------------------------------------------
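For reference, generating a single image from a trained checkpoint might look like the sketch below. The paths are the same placeholders used in __main__ above, and since progan_v15 references helpers (conv2d, resize, avg_pool, FeedDict.next_batch) that predate the ops.py and feed_dict.py shown here, it assumes the module versions this file was written against:

import numpy as np
from progan_v15 import ProGAN

# Restores the latest checkpoint from logdir if one exists
progan = ProGAN(logdir='logdir_v2', imgdir='img_arrays')

z = np.random.normal(size=[progan.z_length])  # a single latent vector
img = progan.generate(z)                      # uint8 image at the current resolution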
/progan_v16.py:
--------------------------------------------------------------------------------
1 | import datetime as dt
2 | import os
3 | 
4 | import numpy as np
5 | 
6 | # Operations used in building the network. Many are not used in the current model
7 | from ops import *
8 | # FeedDict object used to continuously provide new training data
9 | from feed_dict import FeedDict
10 | 
11 | 
12 | # TODO: add argparser and flags
13 | 
14 | 
15 | class ProGAN:
16 |     def __init__(self,
17 |         logdir,                   # directory of stored models
18 |         imgdir,                   # directory of images for FeedDict
19 |         learning_rate=0.001,      # Adam optimizer learning rate
20 |         beta1=0,                  # Adam optimizer beta1
21 |         beta2=0.99,               # Adam optimizer beta2
22 |         w_lambda=10.0,            # WGAN-GP/LP lambda
23 |         w_gamma=1.0,              # WGAN-GP/LP gamma
24 |         epsilon=0.001,            # epsilon drift penalty coefficient
25 |         z_length=512,             # latent variable size
26 |         n_imgs=800000,            # number of images to show in each growth step
27 |         batch_repeats=1,          # number of times to repeat minibatch
28 |         n_examples=24,            # number of example images to generate
29 |         lipschitz_penalty=True,   # if True, use WGAN-LP instead of WGAN-GP
30 |         big_image=True,           # generate a single large preview image; only works if n_examples == 24
31 |         reset_optimizer=True,     # reset optimizer variables with each new layer
32 |         batch_sizes=None,
33 |         channels=None,
34 |     ):
35 | 
36 |         # Per-layer channel counts and batch sizes, overridable via the constructor
37 |         self.channels = channels if channels else [512, 512, 512, 512, 256, 128, 64, 32, 16, 16]
38 |         self.batch_sizes = batch_sizes if batch_sizes else [16, 16, 16, 16, 16, 16, 12, 4, 3]
39 | 
40 |         self.z_length = z_length
41 |         self.n_examples = n_examples
42 |         self.batch_repeats = batch_repeats if batch_repeats else 1
43 |         self.n_imgs = n_imgs
44 |         self.logdir = logdir
45 |         self.big_image = big_image
46 |         self.w_lambda = w_lambda
47 |         self.w_gamma = w_gamma
48 |         self.epsilon = epsilon
49 |         self.reset_optimizer = reset_optimizer
50 |         self.lipschitz_penalty = lipschitz_penalty
51 | 
52 |         # Initialize FeedDict
53 |         self.feed = FeedDict.load(logdir, imgdir=imgdir, z_length=z_length, n_examples=n_examples)
54 |         self.n_layers = self.feed.n_sizes
55 |         self.max_imgs = (self.n_layers - 0.5) * self.n_imgs * 2
56 | 
57 |         # Initialize placeholders
58 |         self.x_placeholder = tf.placeholder(tf.uint8, [None, 3, None, None])
59 |         self.z_placeholder = tf.placeholder(tf.float32, [None, self.z_length])
60 | 
61 |         # Global step
62 |         with tf.variable_scope('global_step'):
63 |             self.global_step = tf.Variable(0, name='global_step', trainable=False, dtype=tf.int32)
64 | 
65 |         # Non-trainable variables for counting to next layer and incrementing value of alpha
66 |         with tf.variable_scope('image_count'):
67 |             self.total_imgs = tf.Variable(0, name='total_images', trainable=False, dtype=tf.int32)
68 | 
69 |         img_offset = tf.add(self.total_imgs, self.n_imgs)
70 |         imgs_per_layer = self.n_imgs * 2
71 | 
72 |         self.img_step = tf.mod(img_offset, imgs_per_layer)
73 |         self.layer = tf.minimum(tf.floor_div(img_offset, imgs_per_layer), self.n_layers - 1)
74 | 
75 |         fade_in = tf.to_float(self.img_step) / float(self.n_imgs)
76 |         self.alpha = tf.minimum(1.0, tf.maximum(0.0, fade_in))
77 | 
78 |         # Initialize the optimizers as member variables if reset_optimizer is False;
79 |         # otherwise a new optimizer is generated for each layer
80 |         if self.reset_optimizer:
81 |             self.lr = learning_rate
82 |             self.beta1 = beta1
83 |             self.beta2 = beta2
84 |         else:
85 |             self.g_optimizer = tf.train.AdamOptimizer(learning_rate, beta1, beta2)
86 |             self.d_optimizer = tf.train.AdamOptimizer(learning_rate, beta1, beta2)
87 |         self.networks = [self.create_network(i + 1) for i in range(self.n_layers)]
88 | 
89 |         # Initialize Session, FileWriter and Saver
90 |         self.sess = tf.Session()
91 | 
self.sess.run(tf.global_variables_initializer()) 92 | self.writer = tf.summary.FileWriter(self.logdir, graph=self.sess.graph) 93 | self.saver = tf.train.Saver() 94 | 95 | # Look in logdir to see if a saved model already exists. If so, load it 96 | try: 97 | self.saver.restore(self.sess, tf.train.latest_checkpoint(self.logdir)) 98 | print('Restored model -----------\n') 99 | except Exception: 100 | pass 101 | 102 | 103 | # Function for fading input of current layer into previous layer based on current value of alpha 104 | def reparameterize(self, x0, x1): 105 | return tf.add( 106 | tf.scalar_mul(tf.subtract(1.0, self.alpha), x0), 107 | tf.scalar_mul(self.alpha, x1) 108 | ) 109 | 110 | 111 | # Build a generator for n layers 112 | def generator(self, z, n_layers): 113 | with tf.variable_scope('Generator'): 114 | 115 | with tf.variable_scope('latent_vector'): 116 | z = tf.expand_dims(z, 2) 117 | g1 = tf.expand_dims(z, 3) 118 | 119 | for i in range(n_layers): 120 | with tf.variable_scope('layer_{}'.format(i)): 121 | 122 | if i == n_layers - 1: 123 | g0 = g1 124 | 125 | with tf.variable_scope('1'): 126 | if i == 0: 127 | g1 = g_conv_layer(g1, self.channels[i], 128 | filter_size=4, padding='VALID', mode='transpose', 129 | output_shape=[tf.shape(g1)[0], self.channels[i], 4, 4]) 130 | else: 131 | g1 = g_conv_layer(g1, self.channels[i], mode='upscale') 132 | 133 | with tf.variable_scope('2'): 134 | g1 = g_conv_layer(g1, self.channels[i]) 135 | 136 | with tf.variable_scope('rgb_layer_{}'.format(n_layers - 1)): 137 | g1 = conv(g1, 3, filter_size=1) 138 | 139 | if n_layers > 1: 140 | with tf.variable_scope('rgb_layer_{}'.format(n_layers - 2)): 141 | g0 = conv(g0, 3, filter_size=1) 142 | g0 = upscale(g0) 143 | g = self.reparameterize(g0, g1) 144 | else: 145 | g = g1 146 | 147 | return g 148 | 149 | 150 | # Build a discriminator n layers 151 | def discriminator(self, x, n_layers): 152 | with tf.variable_scope('Discriminator'): 153 | 154 | if n_layers > 1: 155 | with tf.variable_scope('rgb_layer_{}'.format(n_layers - 2)): 156 | d0 = downscale(x) 157 | d0 = d_conv_layer(d0, self.channels[n_layers - 1], filter_size=1) 158 | 159 | with tf.variable_scope('rgb_layer_{}'.format(n_layers - 1)): 160 | d1 = d_conv_layer(x, self.channels[n_layers], filter_size=1) 161 | 162 | for i in reversed(range(n_layers)): 163 | with tf.variable_scope('layer_{}'.format(i)): 164 | 165 | if i == 0: 166 | d1 = minibatch_stddev(d1) 167 | 168 | with tf.variable_scope('1'): 169 | d1 = d_conv_layer(d1, self.channels[i]) 170 | 171 | with tf.variable_scope('2'): 172 | if i == 0: 173 | d1 = d_conv_layer(d1, self.channels[0], 174 | filter_size=4, padding='VALID') 175 | else: 176 | d1 = d_conv_layer(d1, self.channels[i], mode='downscale') 177 | 178 | if i == n_layers - 1 and n_layers > 1: 179 | d1 = self.reparameterize(d0, d1) 180 | 181 | with tf.variable_scope('dense'): 182 | d = tf.reshape(d1, [-1, self.channels[0]]) 183 | d = dense(d, 1) 184 | 185 | return d 186 | 187 | 188 | # Function for creating network layout at each layer 189 | def create_network(self, n_layers): 190 | 191 | # image dimensions 192 | dim = 2 ** (n_layers + 1) 193 | 194 | # Build the current network 195 | with tf.variable_scope('Network', reuse=tf.AUTO_REUSE): 196 | Gz = self.generator(self.z_placeholder, n_layers) 197 | Dz = self.discriminator(Gz, n_layers) 198 | 199 | # Mix different resolutions of input images according to value of alpha 200 | with tf.variable_scope('training_images'): 201 | x = scale_uint8(self.x_placeholder) 202 | if n_layers > 1: 203 | x0 = 
upscale(downscale(x)) 204 | x1 = x 205 | x = self.reparameterize(x0, x1) 206 | 207 | Dx = self.discriminator(x, n_layers) 208 | 209 | # Fake and real image mixing for WGAN-GP loss function 210 | interp = tf.random_uniform(shape=[tf.shape(Dz)[0], 1, 1, 1], minval=0.0, maxval=1.0) 211 | x_hat = interp * x + (1 - interp) * Gz 212 | Dx_hat = self.discriminator(x_hat, n_layers) 213 | 214 | # Loss function and scalar summaries 215 | with tf.variable_scope('Loss_Function'): 216 | 217 | # Wasserstein Distance 218 | wd = Dz - Dx 219 | 220 | # Gradient/Lipschitz Penalty 221 | grads = tf.gradients(Dx_hat, [x_hat])[0] 222 | slopes = tf.sqrt(tf.reduce_sum(tf.square(grads), [1, 2, 3])) 223 | 224 | if self.lipschitz_penalty: 225 | gp = tf.square(tf.maximum((slopes - self.w_gamma) / self.w_gamma, 0)) 226 | else: 227 | gp = tf.square((slopes - self.w_gamma) / self.w_gamma) 228 | 229 | gp_scaled = self.w_lambda * gp 230 | 231 | # Epsilon penalty keeps discriminator output for drifting too far away from zero 232 | epsilon_cost = self.epsilon * tf.square(Dx) 233 | 234 | # Cost and summary scalars 235 | g_cost = tf.reduce_mean(-Dz) 236 | d_cost = tf.reduce_mean(wd + gp_scaled + epsilon_cost) 237 | wd = tf.abs(tf.reduce_mean(wd)) 238 | gp = tf.reduce_mean(gp) 239 | 240 | # Summaries 241 | wd_sum = tf.summary.scalar('Wasserstein_distance_{}_({}x{})'.format( 242 | n_layers - 1, dim, dim), wd) 243 | gp_sum = tf.summary.scalar('gradient_penalty_{}_({}x{})'.format( 244 | n_layers - 1, dim, dim), gp) 245 | 246 | # Collecting variables to be trained by optimizers 247 | g_vars, d_vars = [], [] 248 | var_scopes = ['layer_{}'.format(i) for i in range(n_layers)] 249 | var_scopes.extend([ 250 | 'dense', 251 | 'rgb_layer_{}'.format(n_layers - 2), 252 | 'rgb_layer_{}'.format(n_layers - 1) 253 | ]) 254 | 255 | for scope in var_scopes: 256 | g_vars.extend(tf.get_collection( 257 | tf.GraphKeys.GLOBAL_VARIABLES, scope='Network/Generator/{}'.format(scope) 258 | )) 259 | d_vars.extend(tf.get_collection( 260 | tf.GraphKeys.GLOBAL_VARIABLES, scope='Network/Discriminator/{}'.format(scope) 261 | )) 262 | 263 | # Generate optimizer operations 264 | # if self.reset_optimizer is True then initialize a new optimizer for each layer 265 | with tf.variable_scope('Optimize'): 266 | if self.reset_optimizer: 267 | g_train = tf.train.AdamOptimizer( 268 | self.lr, self.beta1, self.beta2, name='G_optimizer_{}'.format(n_layers - 1) 269 | ).minimize( 270 | g_cost, var_list=g_vars) 271 | d_train = tf.train.AdamOptimizer( 272 | self.lr, self.beta1, self.beta2, name='D_optimizer_{}'.format(n_layers - 1) 273 | ).minimize( 274 | d_cost, var_list=d_vars, global_step=self.global_step) 275 | 276 | else: 277 | g_train = self.g_optimizer.minimize(g_cost, var_list=g_vars) 278 | d_train = self.d_optimizer.minimize(d_cost, var_list=d_vars, global_step=self.global_step) 279 | 280 | # Increment image count 281 | n_imgs = tf.shape(x)[0] 282 | new_image_count = tf.add(self.total_imgs, n_imgs) 283 | img_step_op = tf.assign(self.total_imgs, new_image_count) 284 | d_train = tf.group(d_train, img_step_op) 285 | 286 | # Print variable names to before running model 287 | print('\nGenerator variables for layer {} ({} x {}):'.format(n_layers - 1, dim, dim)) 288 | print([var.name for var in g_vars]) 289 | print('\nDiscriminator variables for layer {} ({} x {}):'.format(n_layers - 1, dim, dim)) 290 | print([var.name for var in d_vars]) 291 | 292 | # Generate preview images 293 | with tf.variable_scope('image_preview'): 294 | n_real_imgs = min(self.batch_sizes[n_layers - 1], 
4) 295 | fake_imgs = tensor_to_imgs(Gz) 296 | real_imgs = tensor_to_imgs(x[:n_real_imgs]) 297 | 298 | # Upsize images to normal visibility 299 | if dim < 256: 300 | fake_imgs = resize_images(fake_imgs, (256, 256)) 301 | real_imgs = resize_images(real_imgs, (256, 256)) 302 | 303 | # Concatenate images into one large image for preview, only used if 24 preview images are requested 304 | if self.big_image and self.n_examples == 24: 305 | fake_img_list = tf.unstack(fake_imgs, num=24) 306 | fake_img_list = [tf.concat(fake_img_list[6 * i:6 * (i + 1)], 1) for i in range(4)] 307 | fake_imgs = tf.concat(fake_img_list, 0) 308 | fake_imgs = tf.expand_dims(fake_imgs, 0) 309 | 310 | real_img_list = tf.unstack(real_imgs, num=n_real_imgs) 311 | real_imgs = tf.concat(real_img_list, 1) 312 | real_imgs = tf.expand_dims(real_imgs, 0) 313 | 314 | # images summaries 315 | fake_img_sum = tf.summary.image('fake{}x{}'.format(dim, dim), fake_imgs, self.n_examples) 316 | real_img_sum = tf.summary.image('real{}x{}'.format(dim, dim), real_imgs, 4) 317 | 318 | return dict( 319 | wd=wd, gp=gp, wd_sum=wd_sum, gp_sum=gp_sum, g_train=g_train, d_train=d_train, 320 | fake_img_sum=fake_img_sum, real_img_sum=real_img_sum, Gz=Gz 321 | ) 322 | 323 | 324 | # Get current layer, global step, alpha and total number of images used so far 325 | def get_global_vars(self): 326 | gs, layer, img_step, alpha, total_imgs = self.sess.run([ 327 | self.global_step, self.layer, self.img_step, self.alpha, self.total_imgs 328 | ]) 329 | if layer == 0: img_step -= self.n_imgs 330 | return gs, layer, img_step, alpha, total_imgs 331 | 332 | 333 | def get_layer_ops(self, layer): 334 | dim = 2 ** (layer + 2) 335 | batch_size = self.batch_sizes[layer] 336 | n_imgs = self.n_imgs 337 | if layer > 0: n_imgs *= 2 338 | 339 | layer_ops = self.networks[layer] 340 | g_train = layer_ops.get('g_train') 341 | d_train = layer_ops.get('d_train') 342 | get_ops = lambda *op_names: [layer_ops.get(name) for name in op_names] 343 | scalar_sum_ops = get_ops('wd', 'gp', 'wd_sum', 'gp_sum') 344 | img_sum_ops = get_ops('fake_img_sum', 'real_img_sum') 345 | 346 | return dim, batch_size, n_imgs, g_train, d_train, scalar_sum_ops, img_sum_ops 347 | 348 | 349 | # Main training function 350 | def train(self, save_interval=80000): 351 | 352 | def get_loop_progress(layer, img_step): 353 | percent_done = img_step / self.n_imgs 354 | if layer > 0: percent_done /= 2 355 | time = dt.datetime.now() 356 | return time, percent_done 357 | 358 | gs, prev_layer, img_step, alpha, total_imgs = self.get_global_vars() 359 | start_time, start_percent_done = get_loop_progress(prev_layer, img_step) 360 | dim, batch_size, n_imgs, g_train, d_train, scalar_sum_ops, img_sum_ops = self.get_layer_ops(prev_layer) 361 | 362 | save_step = (total_imgs // save_interval + 1) * save_interval 363 | 364 | while total_imgs < self.max_imgs: 365 | gs, layer, img_step, alpha, total_imgs = self.get_global_vars() 366 | 367 | # Get network operations and loss functions for current layer 368 | if layer != prev_layer: 369 | start_time, start_percent_done = get_loop_progress(prev_layer, img_step) 370 | dim, batch_size, n_imgs, g_train, d_train, scalar_sum_ops, img_sum_ops = self.get_layer_ops(layer) 371 | 372 | # Get training data and latent variables to store in feed_dict 373 | feed_dict = { 374 | self.x_placeholder: self.feed.x_batch(batch_size, dim), 375 | self.z_placeholder: self.feed.z_batch(batch_size) 376 | } 377 | 378 | # Here's where we actually train the model 379 | for _ in range(self.batch_repeats): 380 | 
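                # Note: d_train was grouped with img_step_op in create_network, so each
                # discriminator step below also advances total_imgs by the batch size;
                # that counter is what drives alpha and, eventually, the move to the
                # next layer. The discriminator and generator are each updated once
                # per repeat rather than using multiple D steps per G step.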
                self.sess.run(d_train, feed_dict)
381 |                 self.sess.run(g_train, feed_dict)
382 | 
383 |             if gs % 20 == 0:
384 | 
385 |                 # Get loss values and summaries
386 |                 wd_value, gp_value, wd_sum_str, gp_sum_str = self.sess.run(scalar_sum_ops, feed_dict)
387 | 
388 |                 # Print current status, loss functions, etc.
389 |                 time, percent_done = get_loop_progress(layer, img_step)
390 |                 print(
391 |                     'dimensions: ({} x {}) ---- {}% ---- images: {}/{} ---- alpha: {} ---- global step: {}'
392 |                     '\nWasserstein distance: {}\ngradient penalty: {}'.format(
393 |                         dim, dim, np.round(percent_done * 100, 4), img_step, n_imgs,
394 |                         np.round(alpha, 4), gs, wd_value, gp_value
395 |                     ))
396 | 
397 |                 # Calculate and print estimated time remaining
398 |                 delta_t = time - start_time
399 |                 time_remaining = delta_t * (1 / (percent_done - start_percent_done + 1e-8) - 1)
400 |                 print('est. time remaining on layer {}: {}\n'.format(layer, time_remaining))
401 | 
402 |                 # Log scalar data every 20 global steps
403 |                 self.writer.add_summary(wd_sum_str, gs)
404 |                 self.writer.add_summary(gp_sum_str, gs)
405 | 
406 |             # Operations to run every save interval
407 |             if total_imgs > save_step:
408 |                 save_step += save_interval
409 | 
410 |                 # Save the model and generate image previews
411 |                 print('\nsaving and making images...\n')
412 |                 self.saver.save(
413 |                     self.sess, os.path.join(self.logdir, "model.ckpt"),
414 |                     global_step=self.global_step
415 |                 )
416 |                 self.feed.save()
417 | 
418 |                 img_preview_feed_dict = {
419 |                     self.x_placeholder: feed_dict[self.x_placeholder][:4],
420 |                     self.z_placeholder: self.feed.z_fixed
421 |                 }
422 | 
423 |                 fake_img_sum_str, real_img_sum_str = self.sess.run(
424 |                     img_sum_ops, img_preview_feed_dict
425 |                 )
426 |                 self.writer.add_summary(fake_img_sum_str, gs)
427 |                 self.writer.add_summary(real_img_sum_str, gs)
428 | 
429 |             prev_layer = layer
430 | 
431 | 
432 |     def get_cur_res(self):
433 |         cur_layer = self.sess.run(self.layer)
434 |         return 2 ** (2 + cur_layer)
435 | 
436 | 
437 |     def generate(self, z):
438 |         solo = z.ndim == 1
439 |         if solo:
440 |             z = np.expand_dims(z, 0)
441 | 
442 |         cur_layer = int(self.sess.run(self.layer))
443 |         Gz = self.networks[cur_layer]['Gz']  # create_network returns a dict in this version
444 |         imgs = self.sess.run(Gz, {self.z_placeholder: z})
445 | 
446 |         if solo:
447 |             imgs = np.squeeze(imgs, 0)
448 |         return imgs
449 | 
450 | 
451 | if __name__ == '__main__':
452 |     # progan = ProGAN(logdir='logdir_v5', imgdir='memmaps')
453 | 
454 |     # progan = ProGAN(logdir='logdir_v6', imgdir='memmaps', batch_repeats=4)
455 | 
456 |     progan = ProGAN(logdir='logdir_v8', imgdir='memmaps', batch_repeats=4)
457 |     # progan = ProGAN(logdir='logdir_v9', imgdir='memmaps', batch_repeats=4, batch_sizes=[128, 128, 128, 64, 32, 16, 12, 8, 4])
458 | 
459 |     progan.train()
--------------------------------------------------------------------------------
/scripts/downloader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import requests
3 | 
4 | from selenium import webdriver
5 | from selenium.webdriver.common.by import By
6 | from selenium.webdriver.support.ui import WebDriverWait
7 | from selenium.webdriver.support import expected_conditions as EC
8 | 
9 | 
10 | subreddit = input('Enter subreddit name: ')
11 | save_dir = input('Enter name of folder to save images in: ')
12 | 
13 | if not os.path.isdir(save_dir):
14 |     os.makedirs(save_dir)
15 | 
16 | pages = 100
17 | img_n = 0
18 | browser = webdriver.Firefox()
19 | browser.get('https://old.reddit.com/r/{}'.format(subreddit))
20 | 
21 | for i in range(pages):
22 |     icons = WebDriverWait(browser, 
300).until( 23 | EC.presence_of_all_elements_located( 24 | (By.CLASS_NAME, "expando-button") 25 | ) 26 | ) 27 | 28 | for icon in icons: 29 | icon.click() 30 | 31 | links = WebDriverWait(browser, 300).until( 32 | EC.presence_of_all_elements_located((By.CLASS_NAME, "may-blank")) 33 | ) 34 | links = list(set([a.get_attribute('href') for a in links if a.get_attribute('href').endswith('.jpg')])) 35 | 36 | for link in links: 37 | image = requests.get(link) 38 | with open('{}/img_{}.jpg'.format(save_dir, img_n), 'wb') as f: 39 | f.write(image.content) 40 | img_n += 1 41 | 42 | if i != pages - 1: 43 | next_button = WebDriverWait(browser, 300).until( 44 | EC.presence_of_element_located((By.CLASS_NAME, "next-button")) 45 | ) 46 | next_button.click() 47 | 48 | print('page: {}, images: {}'.format(i, len(links))) -------------------------------------------------------------------------------- /scripts/image_reshape.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from PIL import Image 4 | 5 | 6 | def generate_square_crops(imgdir, savedir, crops_per_img=10, max_size=1024, filter=Image.BICUBIC): 7 | 8 | img_files = [os.path.join(imgdir, f) for f in os.listdir(imgdir)] 9 | savedir = os.path.join(savedir, '_temp') 10 | if not os.path.exists(savedir): os.makedirs(savedir) 11 | 12 | for i, f in enumerate(img_files): 13 | 14 | with Image.open(f) as img: 15 | width, height = img.size 16 | 17 | if width < max_size or height < max_size: continue 18 | 19 | landscape = width > height 20 | if landscape: 21 | new_height = max_size 22 | new_width = int(width * (max_size / height)) 23 | offset = int(max_size * (width / height - 1) + 1) 24 | else: 25 | new_width = max_size 26 | new_height = int(height * (max_size / width)) 27 | offset = int(max_size * (height / width - 1) + 1) 28 | 29 | n_crops = min(offset, crops_per_img) 30 | window_slide_len = offset / n_crops 31 | 32 | try: 33 | img = img.convert('RGB') 34 | img = img.resize((new_width, new_height), filter) 35 | 36 | for j in range(n_crops): 37 | shift = int(j * window_slide_len) 38 | 39 | if landscape: window = (shift, 0, max_size + shift, max_size) 40 | else: window = (0, shift, max_size, max_size + shift) 41 | 42 | cropped_img = img.crop(window) 43 | mirror_img = cropped_img.transpose(Image.FLIP_LEFT_RIGHT) 44 | 45 | path = os.path.join(savedir, 'img_{}_{}.jpg'.format(i, j)) 46 | mirror_path = os.path.join(savedir, 'img_{}_{}_mirror.jpg'.format(i, j)) 47 | cropped_img.save(path, "JPEG") 48 | mirror_img.save(mirror_path, "JPEG") 49 | 50 | print('Processed {}\n'.format(f)) 51 | 52 | except OSError: 53 | continue 54 | 55 | 56 | def resize(savedir, NCHW=True, min_size=4, max_size=1024, max_mem=0.8, 57 | use_uint8=True, filter=Image.BICUBIC): 58 | 59 | resized_img_dir = os.path.join(savedir, '_temp') 60 | img_files = [os.path.join(resized_img_dir, f) for f in os.listdir(resized_img_dir)] 61 | np.random.shuffle(img_files) 62 | savedir = os.path.join(savedir, 'memmaps') 63 | if not os.path.exists(savedir): os.makedirs(savedir) 64 | 65 | sizes = [ 66 | 2 ** i for i in range( 67 | int(np.log2(min_size)), 68 | int(np.log2(max_size)) + 1 69 | )] 70 | 71 | pixel_bytes = 3 if use_uint8 else 12 72 | max_bytes = max_mem * 1e9 73 | 74 | for s in sizes: 75 | max_imgs = int(max_bytes / (pixel_bytes * s ** 2)) 76 | batch_shape = (max_imgs, 3, s, s) if NCHW else (max_imgs, s, s, 3) 77 | batch = np.zeros(batch_shape, np.uint8) 78 | img_count = 0 79 | batch_count = 0 80 | 81 | for f in img_files: 82 | 83 | 
with Image.open(f) as img:
84 |                 width, height = img.size
85 | 
86 |                 # Resize whenever either dimension differs from the target size
87 |                 if (width, height) != (s, s):
88 |                     img = img.resize((s, s), filter)
89 |                 img = np.asarray(img, np.uint8)
90 |                 if NCHW:
91 |                     img = np.transpose(img, (2, 0, 1))
92 |                 batch[img_count] = img
93 | 
94 |             if img_count < max_imgs - 1:
95 |                 img_count += 1
96 |             else:
97 |                 path = os.path.join(savedir, '{}_{}.npy'.format(s, batch_count))
98 |                 np.save(path, batch)
99 |                 print('Saved {}'.format(path))
100 |                 img_count = 0
101 |                 batch_count += 1
102 | 
103 |         if img_count != 0:
104 |             path = os.path.join(savedir, '{}_{}.npy'.format(s, batch_count))
105 |             np.save(path, batch[:img_count])
106 |             print('Saved {}'.format(path))
107 | 
108 | 
109 | if __name__ == '__main__':
110 |     imgdir = input('Image directory: ')
111 |     savedir = input('Memmap directory: ')
112 | 
113 |     #generate_square_crops(imgdir, savedir)
114 |     resize(savedir)
--------------------------------------------------------------------------------
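As a quick sanity check of resize()'s output: the default memory budget max_mem=0.8 allows int(0.8e9 / (3 * 1024 ** 2)) = 254 uint8 images per file at 1024x1024. A hypothetical verification of the first saved array:

import numpy as np

# Arrays land in <savedir>/memmaps/{size}_{batch}.npy, NCHW uint8 by default
batch = np.load('memmaps/1024_0.npy')
assert batch.dtype == np.uint8
assert batch.shape[1:] == (3, 1024, 1024)
assert batch.shape[0] <= 254  # int(0.8e9 / (3 * 1024 ** 2))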