├── README.md ├── problem_unittests.py └── helper.py /README.md: -------------------------------------------------------------------------------- 1 | # DCGAN Project 2 | ## Generate Faces 3 | 4 | Repository to submit the project to [Udacity](https://br.udacity.com/course/deep-learning-nanodegree-foundation--nd101) 5 | -------------------------------------------------------------------------------- /problem_unittests.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | from unittest import mock 3 | import tensorflow as tf 4 | 5 | 6 | def test_safe(func): 7 | """ 8 | Isolate tests 9 | """ 10 | def func_wrapper(*args): 11 | with tf.Graph().as_default(): 12 | result = func(*args) 13 | print('Tests Passed') 14 | return result 15 | 16 | return func_wrapper 17 | 18 | 19 | def _assert_tensor_shape(tensor, shape, display_name): 20 | assert tf.assert_rank(tensor, len(shape), message='{} has wrong rank'.format(display_name)) 21 | 22 | tensor_shape = tensor.get_shape().as_list() if len(shape) else [] 23 | 24 | wrong_dimension = [ten_dim for ten_dim, cor_dim in zip(tensor_shape, shape) 25 | if cor_dim is not None and ten_dim != cor_dim] 26 | assert not wrong_dimension, \ 27 | '{} has wrong shape. Found {}'.format(display_name, tensor_shape) 28 | 29 | 30 | def _check_input(tensor, shape, display_name, tf_name=None): 31 | assert tensor.op.type == 'Placeholder', \ 32 | '{} is not a Placeholder.'.format(display_name) 33 | 34 | _assert_tensor_shape(tensor, shape, 'Real Input') 35 | 36 | if tf_name: 37 | assert tensor.name == tf_name, \ 38 | '{} has bad name. Found name {}'.format(display_name, tensor.name) 39 | 40 | 41 | class TmpMock(): 42 | """ 43 | Mock a attribute. Restore attribute when exiting scope. 44 | """ 45 | def __init__(self, module, attrib_name): 46 | self.original_attrib = deepcopy(getattr(module, attrib_name)) 47 | setattr(module, attrib_name, mock.MagicMock()) 48 | self.module = module 49 | self.attrib_name = attrib_name 50 | 51 | def __enter__(self): 52 | return getattr(self.module, self.attrib_name) 53 | 54 | def __exit__(self, type, value, traceback): 55 | setattr(self.module, self.attrib_name, self.original_attrib) 56 | 57 | 58 | @test_safe 59 | def test_model_inputs(model_inputs): 60 | image_width = 28 61 | image_height = 28 62 | image_channels = 3 63 | z_dim = 100 64 | input_real, input_z, learn_rate = model_inputs(image_width, image_height, image_channels, z_dim) 65 | 66 | _check_input(input_real, [None, image_width, image_height, image_channels], 'Real Input') 67 | _check_input(input_z, [None, z_dim], 'Z Input') 68 | _check_input(learn_rate, [], 'Learning Rate') 69 | 70 | 71 | @test_safe 72 | def test_discriminator(discriminator, tf_module): 73 | with TmpMock(tf_module, 'variable_scope') as mock_variable_scope: 74 | image = tf.placeholder(tf.float32, [None, 28, 28, 3]) 75 | 76 | output, logits = discriminator(image) 77 | _assert_tensor_shape(output, [None, 1], 'Discriminator Training(reuse=false) output') 78 | _assert_tensor_shape(logits, [None, 1], 'Discriminator Training(reuse=false) Logits') 79 | assert mock_variable_scope.called,\ 80 | 'tf.variable_scope not called in Discriminator Training(reuse=false)' 81 | assert mock_variable_scope.call_args == mock.call('discriminator', reuse=False), \ 82 | 'tf.variable_scope called with wrong arguments in Discriminator Training(reuse=false)' 83 | 84 | mock_variable_scope.reset_mock() 85 | 86 | output_reuse, logits_reuse = discriminator(image, True) 87 | _assert_tensor_shape(output_reuse, [None, 1], 'Discriminator Inference(reuse=True) output') 88 | _assert_tensor_shape(logits_reuse, [None, 1], 'Discriminator Inference(reuse=True) Logits') 89 | assert mock_variable_scope.called, \ 90 | 'tf.variable_scope not called in Discriminator Inference(reuse=True)' 91 | assert mock_variable_scope.call_args == mock.call('discriminator', reuse=True), \ 92 | 'tf.variable_scope called with wrong arguments in Discriminator Inference(reuse=True)' 93 | 94 | 95 | @test_safe 96 | def test_generator(generator, tf_module): 97 | with TmpMock(tf_module, 'variable_scope') as mock_variable_scope: 98 | z = tf.placeholder(tf.float32, [None, 100]) 99 | out_channel_dim = 5 100 | 101 | output = generator(z, out_channel_dim) 102 | _assert_tensor_shape(output, [None, 28, 28, out_channel_dim], 'Generator output (is_train=True)') 103 | assert mock_variable_scope.called, \ 104 | 'tf.variable_scope not called in Generator Training(reuse=false)' 105 | assert mock_variable_scope.call_args == mock.call('generator', reuse=False), \ 106 | 'tf.variable_scope called with wrong arguments in Generator Training(reuse=false)' 107 | 108 | mock_variable_scope.reset_mock() 109 | output = generator(z, out_channel_dim, False) 110 | _assert_tensor_shape(output, [None, 28, 28, out_channel_dim], 'Generator output (is_train=False)') 111 | assert mock_variable_scope.called, \ 112 | 'tf.variable_scope not called in Generator Inference(reuse=True)' 113 | assert mock_variable_scope.call_args == mock.call('generator', reuse=True), \ 114 | 'tf.variable_scope called with wrong arguments in Generator Inference(reuse=True)' 115 | 116 | 117 | @test_safe 118 | def test_model_loss(model_loss): 119 | out_channel_dim = 4 120 | input_real = tf.placeholder(tf.float32, [None, 28, 28, out_channel_dim]) 121 | input_z = tf.placeholder(tf.float32, [None, 100]) 122 | 123 | d_loss, g_loss = model_loss(input_real, input_z, out_channel_dim) 124 | 125 | _assert_tensor_shape(d_loss, [], 'Discriminator Loss') 126 | _assert_tensor_shape(g_loss, [], 'Generator Loss') 127 | 128 | 129 | @test_safe 130 | def test_model_opt(model_opt, tf_module): 131 | with TmpMock(tf_module, 'trainable_variables') as mock_trainable_variables: 132 | with tf.variable_scope('discriminator'): 133 | discriminator_logits = tf.Variable(tf.zeros([3, 3])) 134 | with tf.variable_scope('generator'): 135 | generator_logits = tf.Variable(tf.zeros([3, 3])) 136 | 137 | mock_trainable_variables.return_value = [discriminator_logits, generator_logits] 138 | d_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits( 139 | logits=discriminator_logits, 140 | labels=[[0.0, 0.0, 1.0], [0.0, 1.0, 0.0], [1.0, 0.0, 0.0]])) 141 | g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits( 142 | logits=generator_logits, 143 | labels=[[0.0, 0.0, 1.0], [0.0, 1.0, 0.0], [1.0, 0.0, 0.0]])) 144 | learning_rate = 0.001 145 | beta1 = 0.9 146 | 147 | d_train_opt, g_train_opt = model_opt(d_loss, g_loss, learning_rate, beta1) 148 | assert mock_trainable_variables.called,\ 149 | 'tf.mock_trainable_variables not called' 150 | 151 | 152 | -------------------------------------------------------------------------------- /helper.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import hashlib 4 | from urllib.request import urlretrieve 5 | import zipfile 6 | import gzip 7 | import shutil 8 | 9 | import numpy as np 10 | from PIL import Image 11 | from tqdm import tqdm 12 | 13 | 14 | def _read32(bytestream): 15 | """ 16 | Read 32-bit integer from bytesteam 17 | :param bytestream: A bytestream 18 | :return: 32-bit integer 19 | """ 20 | dt = np.dtype(np.uint32).newbyteorder('>') 21 | return np.frombuffer(bytestream.read(4), dtype=dt)[0] 22 | 23 | 24 | def _unzip(save_path, _, database_name, data_path): 25 | """ 26 | Unzip wrapper with the same interface as _ungzip 27 | :param save_path: The path of the gzip files 28 | :param database_name: Name of database 29 | :param data_path: Path to extract to 30 | :param _: HACK - Used to have to same interface as _ungzip 31 | """ 32 | print('Extracting {}...'.format(database_name)) 33 | with zipfile.ZipFile(save_path) as zf: 34 | zf.extractall(data_path) 35 | 36 | 37 | def _ungzip(save_path, extract_path, database_name, _): 38 | """ 39 | Unzip a gzip file and extract it to extract_path 40 | :param save_path: The path of the gzip files 41 | :param extract_path: The location to extract the data to 42 | :param database_name: Name of database 43 | :param _: HACK - Used to have to same interface as _unzip 44 | """ 45 | # Get data from save_path 46 | with open(save_path, 'rb') as f: 47 | with gzip.GzipFile(fileobj=f) as bytestream: 48 | magic = _read32(bytestream) 49 | if magic != 2051: 50 | raise ValueError('Invalid magic number {} in file: {}'.format(magic, f.name)) 51 | num_images = _read32(bytestream) 52 | rows = _read32(bytestream) 53 | cols = _read32(bytestream) 54 | buf = bytestream.read(rows * cols * num_images) 55 | data = np.frombuffer(buf, dtype=np.uint8) 56 | data = data.reshape(num_images, rows, cols) 57 | 58 | # Save data to extract_path 59 | for image_i, image in enumerate( 60 | tqdm(data, unit='File', unit_scale=True, miniters=1, desc='Extracting {}'.format(database_name))): 61 | Image.fromarray(image, 'L').save(os.path.join(extract_path, 'image_{}.jpg'.format(image_i))) 62 | 63 | 64 | def get_image(image_path, width, height, mode): 65 | """ 66 | Read image from image_path 67 | :param image_path: Path of image 68 | :param width: Width of image 69 | :param height: Height of image 70 | :param mode: Mode of image 71 | :return: Image data 72 | """ 73 | image = Image.open(image_path) 74 | 75 | if image.size != (width, height): # HACK - Check if image is from the CELEBA dataset 76 | # Remove most pixels that aren't part of a face 77 | face_width = face_height = 108 78 | j = (image.size[0] - face_width) // 2 79 | i = (image.size[1] - face_height) // 2 80 | image = image.crop([j, i, j + face_width, i + face_height]) 81 | image = image.resize([width, height], Image.BILINEAR) 82 | 83 | return np.array(image.convert(mode)) 84 | 85 | 86 | def get_batch(image_files, width, height, mode): 87 | data_batch = np.array( 88 | [get_image(sample_file, width, height, mode) for sample_file in image_files]).astype(np.float32) 89 | 90 | # Make sure the images are in 4 dimensions 91 | if len(data_batch.shape) < 4: 92 | data_batch = data_batch.reshape(data_batch.shape + (1,)) 93 | 94 | return data_batch 95 | 96 | 97 | def images_square_grid(images, mode): 98 | """ 99 | Save images as a square grid 100 | :param images: Images to be used for the grid 101 | :param mode: The mode to use for images 102 | :return: Image of images in a square grid 103 | """ 104 | # Get maximum size for square grid of images 105 | save_size = math.floor(np.sqrt(images.shape[0])) 106 | 107 | # Scale to 0-255 108 | images = (((images - images.min()) * 255) / (images.max() - images.min())).astype(np.uint8) 109 | 110 | # Put images in a square arrangement 111 | images_in_square = np.reshape( 112 | images[:save_size*save_size], 113 | (save_size, save_size, images.shape[1], images.shape[2], images.shape[3])) 114 | if mode == 'L': 115 | images_in_square = np.squeeze(images_in_square, 4) 116 | 117 | # Combine images to grid image 118 | new_im = Image.new(mode, (images.shape[1] * save_size, images.shape[2] * save_size)) 119 | for col_i, col_images in enumerate(images_in_square): 120 | for image_i, image in enumerate(col_images): 121 | im = Image.fromarray(image, mode) 122 | new_im.paste(im, (col_i * images.shape[1], image_i * images.shape[2])) 123 | 124 | return new_im 125 | 126 | 127 | def download_extract(database_name, data_path): 128 | """ 129 | Download and extract database 130 | :param database_name: Database name 131 | """ 132 | DATASET_CELEBA_NAME = 'celeba' 133 | DATASET_MNIST_NAME = 'mnist' 134 | 135 | if database_name == DATASET_CELEBA_NAME: 136 | url = 'https://s3-us-west-1.amazonaws.com/udacity-dlnfd/datasets/celeba.zip' 137 | hash_code = '00d2c5bc6d35e252742224ab0c1e8fcb' 138 | extract_path = os.path.join(data_path, 'img_align_celeba') 139 | save_path = os.path.join(data_path, 'celeba.zip') 140 | extract_fn = _unzip 141 | elif database_name == DATASET_MNIST_NAME: 142 | url = 'http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz' 143 | hash_code = 'f68b3c2dcbeaaa9fbdd348bbdeb94873' 144 | extract_path = os.path.join(data_path, 'mnist') 145 | save_path = os.path.join(data_path, 'train-images-idx3-ubyte.gz') 146 | extract_fn = _ungzip 147 | 148 | if os.path.exists(extract_path): 149 | print('Found {} Data'.format(database_name)) 150 | return 151 | 152 | if not os.path.exists(data_path): 153 | os.makedirs(data_path) 154 | 155 | if not os.path.exists(save_path): 156 | with DLProgress(unit='B', unit_scale=True, miniters=1, desc='Downloading {}'.format(database_name)) as pbar: 157 | urlretrieve( 158 | url, 159 | save_path, 160 | pbar.hook) 161 | 162 | assert hashlib.md5(open(save_path, 'rb').read()).hexdigest() == hash_code, \ 163 | '{} file is corrupted. Remove the file and try again.'.format(save_path) 164 | 165 | os.makedirs(extract_path) 166 | try: 167 | extract_fn(save_path, extract_path, database_name, data_path) 168 | except Exception as err: 169 | shutil.rmtree(extract_path) # Remove extraction folder if there is an error 170 | raise err 171 | 172 | # Remove compressed data 173 | os.remove(save_path) 174 | 175 | 176 | class Dataset(object): 177 | """ 178 | Dataset 179 | """ 180 | def __init__(self, dataset_name, data_files): 181 | """ 182 | Initalize the class 183 | :param dataset_name: Database name 184 | :param data_files: List of files in the database 185 | """ 186 | DATASET_CELEBA_NAME = 'celeba' 187 | DATASET_MNIST_NAME = 'mnist' 188 | IMAGE_WIDTH = 28 189 | IMAGE_HEIGHT = 28 190 | 191 | if dataset_name == DATASET_CELEBA_NAME: 192 | self.image_mode = 'RGB' 193 | image_channels = 3 194 | 195 | elif dataset_name == DATASET_MNIST_NAME: 196 | self.image_mode = 'L' 197 | image_channels = 1 198 | 199 | self.data_files = data_files 200 | self.shape = len(data_files), IMAGE_WIDTH, IMAGE_HEIGHT, image_channels 201 | 202 | def get_batches(self, batch_size): 203 | """ 204 | Generate batches 205 | :param batch_size: Batch Size 206 | :return: Batches of data 207 | """ 208 | IMAGE_MAX_VALUE = 255 209 | 210 | current_index = 0 211 | while current_index + batch_size <= self.shape[0]: 212 | data_batch = get_batch( 213 | self.data_files[current_index:current_index + batch_size], 214 | *self.shape[1:3], 215 | self.image_mode) 216 | 217 | current_index += batch_size 218 | 219 | yield data_batch / IMAGE_MAX_VALUE - 0.5 220 | 221 | 222 | class DLProgress(tqdm): 223 | """ 224 | Handle Progress Bar while Downloading 225 | """ 226 | last_block = 0 227 | 228 | def hook(self, block_num=1, block_size=1, total_size=None): 229 | """ 230 | A hook function that will be called once on establishment of the network connection and 231 | once after each block read thereafter. 232 | :param block_num: A count of blocks transferred so far 233 | :param block_size: Block size in bytes 234 | :param total_size: The total size of the file. This may be -1 on older FTP servers which do not return 235 | a file size in response to a retrieval request. 236 | """ 237 | self.total = total_size 238 | self.update((block_num - self.last_block) * block_size) 239 | self.last_block = block_num 240 | --------------------------------------------------------------------------------