├── 0.jpg
├── LICENSE
├── README.md
└── h5tool.py

--------------------------------------------------------------------------------
/0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/willylulu/celeba-hq-modified/HEAD/0.jpg
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2018 willylulu

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# celeba-hq-modified

![imga](https://raw.githubusercontent.com/willylulu/celeba-hq-modified/master/0.jpg)

Changes I made compared to the original repo (see the fork origin):
- Added multithreading support for a substantial speedup.
- Images are saved under their original CelebA filenames (instead of a running index).

A modified approach to generating the __CelebA-HQ dataset__.

CelebA-HQ is a dataset introduced in _Progressive Growing of GANs for Improved Quality, Stability, and Variation_ ([progressive_growing_of_gans](https://github.com/tkarras/progressive_growing_of_gans)), containing 30,000 high-quality images derived from CelebA.

The images are originally stored in _HDF5 format (.h5)_, which is inconvenient for common data loaders. This fork therefore modifies `h5tool.py` to generate and save the CelebA-HQ images in _JPEG format (.jpg)_ instead.


## Usage

1. Clone the original repository (its Theano branch)

`git clone -b original-theano-version https://github.com/tkarras/progressive_growing_of_gans.git`

2. Clone this repository

`git clone https://github.com/willylulu/celeba-hq-modified`

3. Replace `h5tool.py` in the original repo with the one from this repo

`cp celeba-hq-modified/h5tool.py progressive_growing_of_gans/h5tool.py`

4. Create the target directories in the original repository

```
cd progressive_growing_of_gans
mkdir celeba-hq
cd celeba-hq
mkdir celeba-64
mkdir celeba-128
mkdir celeba-256
mkdir celeba-512
mkdir celeba-1024
```

5. Go back to your home directory

`cd`

6. Create a directory A, download the [CelebA **non-aligned version**](https://drive.google.com/open?id=0B7EVK8r0v71peklHb0pGdDl6R28), and put the images in directory A (the script expects them under `A/img_celeba`)

7. Create a directory B, download the [CelebA-HQ zip files](https://drive.google.com/drive/folders/0B4qLcYyJmiz0TXY1NG02bzZVRGs) (`deltas*.zip` plus `image_list.txt`), and put them in directory B

8. Download the [annotation files](https://drive.google.com/open?id=0B7EVK8r0v71pOC0wOVZlQnFfaGs) and put them in directory A (the script reads `A/Anno/list_landmarks_celeba.txt`)

9. Execute `h5tool.py` from inside `progressive_growing_of_gans`; the JPEGs are written to `./celeba-hq/`, and the `.h5` filename argument is kept only for CLI compatibility

`python h5tool.py create_celeba_hq 123456.h5 <path to A> <path to B>`
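
## Loading the generated images

Once step 9 finishes, each `celeba-hq/celeba-<resolution>` directory can be read by any common image loader. A minimal sketch (assuming Pillow is installed; `root` points at wherever step 9 was run):

```python
import glob
import os

from PIL import Image

root = './celeba-hq/celeba-1024'  # or celeba-64 / celeba-128 / celeba-256 / celeba-512
paths = sorted(glob.glob(os.path.join(root, '*.jpg')))
img = Image.open(paths[0])  # files keep their original CelebA names
print(len(paths))  # expect 30000
print(img.size)    # (1024, 1024)
```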

## Reference

[tkarras/progressive_growing_of_gans](https://github.com/tkarras/progressive_growing_of_gans)

[CelebA](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html)
--------------------------------------------------------------------------------
/h5tool.py:
--------------------------------------------------------------------------------
# Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.
#
# This work is licensed under the Creative Commons Attribution-NonCommercial
# 4.0 International License. To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.

import os
import sys
import io
import glob
import pickle
import argparse
import threading
import Queue
import traceback
import numpy as np
import scipy.ndimage
import PIL.Image
import h5py # conda install h5py

#----------------------------------------------------------------------------

class HDF5Exporter:
    def __init__(self, h5_filename, resolution, channels=3):
        rlog2 = int(np.floor(np.log2(resolution)))
        assert resolution == 2 ** rlog2
        self.resolution = resolution
        self.channels = channels
        self.h5_file = h5py.File(h5_filename, 'w')
        self.h5_lods = []
        self.buffers = []
        self.buffer_sizes = []
        for lod in xrange(rlog2, -1, -1):
            r = 2 ** lod; c = channels
            bytes_per_item = c * (r ** 2)
            chunk_size = int(np.ceil(128.0 / bytes_per_item))
            buffer_size = int(np.ceil(512.0 * np.exp2(20) / bytes_per_item))
            dataset = self.h5_file.create_dataset('data%dx%d' % (r,r), shape=(0,c,r,r), dtype=np.uint8,
                maxshape=(None,c,r,r), chunks=(chunk_size,c,r,r), compression='gzip', compression_opts=4)
            self.h5_lods.append(dataset)
            self.buffers.append(np.zeros((buffer_size,c,r,r), dtype=np.uint8))
            self.buffer_sizes.append(0)

    def close(self):
        for lod in xrange(len(self.h5_lods)):
            self.flush_lod(lod)
        self.h5_file.close()

    def add_images(self, img):
        assert img.ndim == 4 and img.shape[1] == self.channels and img.shape[2] == img.shape[3]
        assert img.shape[2] >= self.resolution and img.shape[2] == 2 ** int(np.floor(np.log2(img.shape[2])))
        for lod in xrange(len(self.h5_lods)):
            while img.shape[2] > self.resolution / (2 ** lod):
                img = img.astype(np.float32)
                img = (img[:, :, 0::2, 0::2] + img[:, :, 0::2, 1::2] + img[:, :, 1::2, 0::2] + img[:, :, 1::2, 1::2]) * 0.25
            quant = np.uint8(np.clip(np.round(img), 0, 255))
            ofs = 0
            while ofs < quant.shape[0]:
                num = min(quant.shape[0] - ofs, self.buffers[lod].shape[0] - self.buffer_sizes[lod])
                self.buffers[lod][self.buffer_sizes[lod] : self.buffer_sizes[lod] + num] = quant[ofs : ofs + num]
                self.buffer_sizes[lod] += num
                if self.buffer_sizes[lod] == self.buffers[lod].shape[0]:
                    self.flush_lod(lod)
                ofs += num

    def num_images(self):
        return self.h5_lods[0].shape[0] + self.buffer_sizes[0]

    def flush_lod(self, lod):
        num = self.buffer_sizes[lod]
        if num > 0:
            self.h5_lods[lod].resize(self.h5_lods[lod].shape[0] + num, axis=0)
            self.h5_lods[lod][-num:] = self.buffers[lod][:num]
            self.buffer_sizes[lod] = 0

#----------------------------------------------------------------------------
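
# Usage sketch for HDF5Exporter (illustrative; filenames are hypothetical).
# The exporter keeps one dataset per power-of-two LOD ('data64x64' down to
# 'data1x1' here) and box-downsamples each incoming (N, C, H, W) uint8 batch
# by averaging 2x2 pixel blocks:
#
#   exporter = HDF5Exporter('toy-64x64.h5', resolution=64, channels=3)
#   batch = np.uint8(np.random.randint(0, 256, (8, 3, 64, 64)))
#   exporter.add_images(batch)
#   exporter.close() # flushes buffered rows into the file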

class ExceptionInfo(object):
    def __init__(self):
        self.type, self.value = sys.exc_info()[:2]
        self.traceback = traceback.format_exc()

#----------------------------------------------------------------------------

class WorkerThread(threading.Thread):
    def __init__(self, task_queue):
        threading.Thread.__init__(self)
        self.task_queue = task_queue

    def run(self):
        while True:
            func, args, result_queue = self.task_queue.get()
            if func is None:
                break
            try:
                result = func(*args)
            except:
                result = ExceptionInfo()
            result_queue.put((result, args))

#----------------------------------------------------------------------------

class ThreadPool(object):
    def __init__(self, num_threads):
        assert num_threads >= 1
        self.task_queue = Queue.Queue()
        self.result_queues = dict()
        self.num_threads = num_threads
        for idx in xrange(self.num_threads):
            thread = WorkerThread(self.task_queue)
            thread.daemon = True
            thread.start()

    def add_task(self, func, args=()):
        assert hasattr(func, '__call__') # must be a function
        if func not in self.result_queues:
            self.result_queues[func] = Queue.Queue()
        self.task_queue.put((func, args, self.result_queues[func]))

    def get_result(self, func, verbose_exceptions=True): # returns (result, args)
        result, args = self.result_queues[func].get()
        if isinstance(result, ExceptionInfo):
            if verbose_exceptions:
                print '\n\nWorker thread caught an exception:\n' + result.traceback + '\n',
            raise result.type, result.value
        return result, args

    def finish(self):
        for idx in xrange(self.num_threads):
            self.task_queue.put((None, (), None))

    def __enter__(self): # for 'with' statement
        return self

    def __exit__(self, *excinfo):
        self.finish()

    def process_items_concurrently(self, item_iterator, process_func=lambda x: x, pre_func=lambda x: x, post_func=lambda x: x, max_items_in_flight=None):
        if max_items_in_flight is None: max_items_in_flight = self.num_threads * 4
        assert max_items_in_flight >= 1
        results = []
        retire_idx = [0]

        def task_func(prepared, idx):
            return process_func(prepared)

        def retire_result():
            processed, (prepared, idx) = self.get_result(task_func)
            results[idx] = processed
            while retire_idx[0] < len(results) and results[retire_idx[0]] is not None:
                yield post_func(results[retire_idx[0]])
                results[retire_idx[0]] = None
                retire_idx[0] += 1

        for idx, item in enumerate(item_iterator):
            prepared = pre_func(item)
            results.append(None)
            self.add_task(func=task_func, args=(prepared, idx))
            while retire_idx[0] < idx - max_items_in_flight + 2:
                for res in retire_result(): yield res
        while retire_idx[0] < len(results):
            for res in retire_result(): yield res

#----------------------------------------------------------------------------
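
# Usage sketch for ThreadPool (illustrative): tasks run on worker threads,
# but process_items_concurrently yields results in submission order, e.g.:
#
#   with ThreadPool(num_threads=4) as pool:
#       for y in pool.process_items_concurrently(range(10), process_func=lambda x: x * x):
#           print y # 0, 1, 4, 9, ... despite concurrent execution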

def inspect(h5_filename):
    print '%-20s%s' % ('HDF5 filename', h5_filename)
    file_size = os.stat(h5_filename).st_size
    print '%-20s%.2f GB' % ('Total size', float(file_size) / np.exp2(30))

    h5 = h5py.File(h5_filename, 'r')
    lods = sorted([value for key, value in h5.iteritems() if key.startswith('data')], key=lambda lod: -lod.shape[3])
    shapes = [lod.shape for lod in lods]
    shape = shapes[0]
    h5.close()
    print '%-20s%d' % ('Total images', shape[0])
    print '%-20s%dx%d' % ('Resolution', shape[3], shape[2])
    print '%-20s%d' % ('Color channels', shape[1])
    print '%-20s%.2f KB' % ('Size per image', float(file_size) / shape[0] / np.exp2(10))

    if len(lods) != int(np.log2(shape[3])) + 1:
        print 'Warning: The HDF5 file contains an incorrect number of LODs'
    if any(s[0] != shape[0] for s in shapes):
        print 'Warning: The HDF5 file contains an inconsistent number of images in different LODs'
        print 'Perhaps the dataset creation script was terminated abruptly?'

#----------------------------------------------------------------------------

def compare(first_h5, second_h5):
    print 'Comparing %s vs. %s' % (first_h5, second_h5)
    h5_a = h5py.File(first_h5, 'r')
    h5_b = h5py.File(second_h5, 'r')
    lods_a = sorted([value for key, value in h5_a.iteritems() if key.startswith('data')], key=lambda lod: -lod.shape[3])
    lods_b = sorted([value for key, value in h5_b.iteritems() if key.startswith('data')], key=lambda lod: -lod.shape[3])
    shape_a = lods_a[0].shape
    shape_b = lods_b[0].shape

    if shape_a[1] != shape_b[1]:
        print 'The datasets have a different number of color channels: %d vs. %d' % (shape_a[1], shape_b[1])
    elif shape_a[3] != shape_b[3] or shape_a[2] != shape_b[2]:
        print 'The datasets have different resolutions: %dx%d vs. %dx%d' % (shape_a[3], shape_a[2], shape_b[3], shape_b[2])
    else:
        min_images = min(shape_a[0], shape_b[0])
        num_diffs = 0
        for idx in xrange(min_images):
            print '%d / %d\r' % (idx, min_images),
            if np.any(lods_a[0][idx] != lods_b[0][idx]):
                print '%-40s\r' % '',
                print 'Different image: %d' % idx
                num_diffs += 1
        if shape_a[0] != shape_b[0]:
            print 'The datasets contain a different number of images: %d vs. %d' % (shape_a[0], shape_b[0])
        if num_diffs == 0:
            print 'All %d images are identical.' % min_images
        else:
            print '%d images out of %d are different.' % (num_diffs, min_images)

    h5_a.close()
    h5_b.close()

#----------------------------------------------------------------------------

def display(h5_filename, start=None, stop=None, step=None):
    print 'Displaying images from %s' % h5_filename
    h5 = h5py.File(h5_filename, 'r')
    lods = sorted([value for key, value in h5.iteritems() if key.startswith('data')], key=lambda lod: -lod.shape[3])
    indices = range(lods[0].shape[0])
    indices = indices[start : stop : step]

    import cv2 # pip install opencv-python
    window_name = 'h5tool'
    cv2.namedWindow(window_name)
    print 'Press SPACE or ENTER to advance, ESC to exit.'

    for idx in indices:
        print '%d / %d\r' % (idx, lods[0].shape[0]),
        img = lods[0][idx]
        img = img.transpose(1, 2, 0) # CHW => HWC
        img = img[:, :, ::-1] # RGB => BGR
        cv2.imshow(window_name, img)
        c = cv2.waitKey()
        if c == 27:
            break

    h5.close()
    print '%-40s\r' % '',
    print 'Done.'
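
# Example invocations of the inspection commands above (the CLI is defined in
# execute_cmdline at the bottom of this file):
#
#   python h5tool.py inspect mnist-32x32.h5
#   python h5tool.py compare mydataset.h5 mnist-32x32.h5
#   python h5tool.py display mnist-32x32.h5 --start 0 --stop 10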

#----------------------------------------------------------------------------

def extract(h5_filename, output_dir, start=None, stop=None, step=None):
    print 'Extracting images from %s to %s' % (h5_filename, output_dir)
    h5 = h5py.File(h5_filename, 'r')
    lods = sorted([value for key, value in h5.iteritems() if key.startswith('data')], key=lambda lod: -lod.shape[3])
    shape = lods[0].shape
    indices = range(shape[0])[start : stop : step]
    if not os.path.isdir(output_dir):
        os.makedirs(output_dir)

    for idx in indices:
        print '%d / %d\r' % (idx, shape[0]),
        img = lods[0][idx]
        if img.shape[0] == 1:
            img = PIL.Image.fromarray(img[0], 'L')
        else:
            img = PIL.Image.fromarray(img.transpose(1, 2, 0), 'RGB')
        img.save(os.path.join(output_dir, 'img%08d.png' % idx))

    h5.close()
    print '%-40s\r' % '',
    print 'Extracted %d images.' % len(indices)

#----------------------------------------------------------------------------

def create_custom(h5_filename, image_dir):
    print 'Creating custom dataset %s from %s' % (h5_filename, image_dir)
    glob_pattern = os.path.join(image_dir, '*')
    image_filenames = sorted(glob.glob(glob_pattern))
    if len(image_filenames) == 0:
        print 'Error: No input images found in %s' % glob_pattern
        return

    img = np.asarray(PIL.Image.open(image_filenames[0]))
    resolution = img.shape[0]
    channels = img.shape[2] if img.ndim == 3 else 1
    if img.shape[1] != resolution:
        print 'Error: Input images must have the same width and height'
        return
    if resolution != 2 ** int(np.floor(np.log2(resolution))):
        print 'Error: Input image resolution must be a power-of-two'
        return
    if channels not in [1, 3]:
        print 'Error: Input images must be stored as RGB or grayscale'
        return

    h5 = HDF5Exporter(h5_filename, resolution, channels)
    for idx in xrange(len(image_filenames)):
        print '%d / %d\r' % (idx, len(image_filenames)),
        img = np.asarray(PIL.Image.open(image_filenames[idx]))
        if channels == 1:
            img = img[np.newaxis, :, :] # HW => CHW
        else:
            img = img.transpose(2, 0, 1) # HWC => CHW
        h5.add_images(img[np.newaxis])

    print '%-40s\r' % 'Flushing data...',
    h5.close()
    print '%-40s\r' % '',
    print 'Added %d images.' % len(image_filenames)
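
# Example invocations (filenames are placeholders):
#
#   python h5tool.py extract mnist-32x32.h5 mnist-images --start 0 --stop 100
#   python h5tool.py create_custom mydataset.h5 myimagedir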

#----------------------------------------------------------------------------

def create_mnist(h5_filename, mnist_dir, export_labels=False):
    print 'Loading MNIST data from %s' % mnist_dir
    import gzip
    with gzip.open(os.path.join(mnist_dir, 'train-images-idx3-ubyte.gz'), 'rb') as file:
        images = np.frombuffer(file.read(), np.uint8, offset=16)
    with gzip.open(os.path.join(mnist_dir, 'train-labels-idx1-ubyte.gz'), 'rb') as file:
        labels = np.frombuffer(file.read(), np.uint8, offset=8)
    images = images.reshape(-1, 1, 28, 28)
    images = np.pad(images, [(0,0), (0,0), (2,2), (2,2)], 'constant', constant_values=0)
    assert images.shape == (60000, 1, 32, 32) and images.dtype == np.uint8
    assert labels.shape == (60000,) and labels.dtype == np.uint8
    assert np.min(images) == 0 and np.max(images) == 255
    assert np.min(labels) == 0 and np.max(labels) == 9

    print 'Creating %s' % h5_filename
    h5 = HDF5Exporter(h5_filename, 32, 1)
    h5.add_images(images)
    h5.close()

    if export_labels:
        npy_filename = os.path.splitext(h5_filename)[0] + '-labels.npy'
        print 'Creating %s' % npy_filename
        onehot = np.zeros((labels.size, np.max(labels) + 1), dtype=np.float32)
        onehot[np.arange(labels.size), labels] = 1.0
        np.save(npy_filename, onehot)
    print 'Added %d images.' % images.shape[0]

#----------------------------------------------------------------------------

def create_mnist_rgb(h5_filename, mnist_dir, num_images=1000000, random_seed=123):
    print 'Loading MNIST data from %s' % mnist_dir
    import gzip
    with gzip.open(os.path.join(mnist_dir, 'train-images-idx3-ubyte.gz'), 'rb') as file:
        images = np.frombuffer(file.read(), np.uint8, offset=16)
    images = images.reshape(-1, 28, 28)
    images = np.pad(images, [(0,0), (2,2), (2,2)], 'constant', constant_values=0)
    assert images.shape == (60000, 32, 32) and images.dtype == np.uint8
    assert np.min(images) == 0 and np.max(images) == 255

    print 'Creating %s' % h5_filename
    h5 = HDF5Exporter(h5_filename, 32, 3)
    np.random.seed(random_seed)
    for idx in xrange(num_images):
        if idx % 100 == 0:
            print '%d / %d\r' % (idx, num_images),
        h5.add_images(images[np.newaxis, np.random.randint(images.shape[0], size=3)])

    print '%-40s\r' % 'Flushing data...',
    h5.close()
    print '%-40s\r' % '',
    print 'Added %d images.' % num_images
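
# Example invocations (see the corresponding sub-commands in execute_cmdline):
#
#   python h5tool.py create_mnist mnist-32x32.h5 ~/mnist --export_labels
#   python h5tool.py create_mnist_rgb mnist-rgb-32x32.h5 ~/mnist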

#----------------------------------------------------------------------------

def create_cifar10(h5_filename, cifar10_dir, export_labels=False):
    print 'Loading CIFAR-10 data from %s' % cifar10_dir
    images = []
    labels = []
    for batch in xrange(1, 6):
        with open(os.path.join(cifar10_dir, 'data_batch_%d' % batch), 'rb') as file:
            data = pickle.load(file)
        images.append(data['data'].reshape(-1, 3, 32, 32))
        labels.append(np.uint8(data['labels']))
    images = np.concatenate(images)
    labels = np.concatenate(labels)

    assert images.shape == (50000, 3, 32, 32) and images.dtype == np.uint8
    assert labels.shape == (50000,) and labels.dtype == np.uint8
    assert np.min(images) == 0 and np.max(images) == 255
    assert np.min(labels) == 0 and np.max(labels) == 9

    print 'Creating %s' % h5_filename
    h5 = HDF5Exporter(h5_filename, 32, 3)
    h5.add_images(images)
    h5.close()

    if export_labels:
        npy_filename = os.path.splitext(h5_filename)[0] + '-labels.npy'
        print 'Creating %s' % npy_filename
        onehot = np.zeros((labels.size, np.max(labels) + 1), dtype=np.float32)
        onehot[np.arange(labels.size), labels] = 1.0
        np.save(npy_filename, onehot)
    print 'Added %d images.' % images.shape[0]

#----------------------------------------------------------------------------

def create_lsun(h5_filename, lmdb_dir, resolution=256, max_images=None):
    print 'Creating LSUN dataset %s from %s' % (h5_filename, lmdb_dir)
    import lmdb # pip install lmdb
    import cv2 # pip install opencv-python
    with lmdb.open(lmdb_dir, readonly=True).begin(write=False) as txn:
        total_images = txn.stat()['entries']
        if max_images is None:
            max_images = total_images

        h5 = HDF5Exporter(h5_filename, resolution, 3)
        for idx, (key, value) in enumerate(txn.cursor()):
            print '%d / %d\r' % (h5.num_images(), min(h5.num_images() + total_images - idx, max_images)),
            try:
                try:
                    img = cv2.imdecode(np.fromstring(value, dtype=np.uint8), 1)
                    if img is None:
                        raise IOError('cv2.imdecode failed')
                    img = img[:, :, ::-1] # BGR => RGB
                except IOError:
                    img = np.asarray(PIL.Image.open(io.BytesIO(value)))
                crop = np.min(img.shape[:2])
                img = img[(img.shape[0] - crop) / 2 : (img.shape[0] + crop) / 2, (img.shape[1] - crop) / 2 : (img.shape[1] + crop) / 2]
                img = PIL.Image.fromarray(img, 'RGB')
                img = img.resize((resolution, resolution), PIL.Image.ANTIALIAS)
                img = np.asarray(img)
                img = img.transpose(2, 0, 1) # HWC => CHW
                h5.add_images(img[np.newaxis])
            except:
                print '%-40s\r' % '',
                print sys.exc_info()[1]
                raise
            if h5.num_images() == max_images:
                break

    print '%-40s\r' % 'Flushing data...',
    num_added = h5.num_images()
    h5.close()
    print '%-40s\r' % '',
    print 'Added %d images.' % num_added
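
# Example invocations (paths are placeholders):
#
#   python h5tool.py create_cifar10 cifar-10-32x32.h5 ~/cifar10 --export_labels
#   python h5tool.py create_lsun lsun-airplane-256x256-100k.h5 ~/lsun/airplane_lmdb --resolution 256 --max_images 100000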

#----------------------------------------------------------------------------

def create_celeba(h5_filename, celeba_dir, cx=89, cy=121):
    print 'Creating CelebA dataset %s from %s' % (h5_filename, celeba_dir)
    glob_pattern = os.path.join(celeba_dir, 'img_align_celeba_png', '*.png')
    image_filenames = sorted(glob.glob(glob_pattern))
    num_images = 202599
    if len(image_filenames) != num_images:
        print 'Error: Expected to find %d images in %s' % (num_images, glob_pattern)
        return

    h5 = HDF5Exporter(h5_filename, 128, 3)
    for idx in xrange(num_images):
        print '%d / %d\r' % (idx, num_images),
        img = np.asarray(PIL.Image.open(image_filenames[idx]))
        assert img.shape == (218, 178, 3)
        img = img[cy - 64 : cy + 64, cx - 64 : cx + 64]
        img = img.transpose(2, 0, 1) # HWC => CHW
        h5.add_images(img[np.newaxis])

    print '%-40s\r' % 'Flushing data...',
    h5.close()
    print '%-40s\r' % '',
    print 'Added %d images.' % num_images

#----------------------------------------------------------------------------

def create_celeba_hq(h5_filename, celeba_dir, delta_dir, num_threads=4, num_tasks=100):
    print 'Loading CelebA data from %s' % celeba_dir
    glob_pattern = os.path.join(celeba_dir, 'img_celeba', '*.jpg')
    glob_expected = 202599
    if len(glob.glob(glob_pattern)) != glob_expected:
        print 'Error: Expected to find %d images in %s' % (glob_expected, glob_pattern)
        return
    with open(os.path.join(celeba_dir, 'Anno', 'list_landmarks_celeba.txt'), 'rt') as file:
        landmarks = [[float(value) for value in line.split()[1:]] for line in file.readlines()[2:]]
    # Some landmark lines may be malformed; zero them out so indexing stays
    # aligned with the image list.
    for i in range(len(landmarks)):
        if len(landmarks[i]) != 10:
            landmarks[i] = [0] * 10
        landmarks[i] = np.reshape(landmarks[i], [5, 2])
    landmarks = np.array(landmarks)
    print(landmarks.shape)

    print 'Loading CelebA-HQ deltas from %s' % delta_dir
    import hashlib
    import bz2
    import zipfile
    import base64
    import cryptography.hazmat.primitives.hashes
    import cryptography.hazmat.backends
    import cryptography.hazmat.primitives.kdf.pbkdf2
    import cryptography.fernet
    glob_pattern = os.path.join(delta_dir, 'delta*.zip')
    glob_expected = 30
    if len(glob.glob(glob_pattern)) != glob_expected:
        print 'Error: Expected to find %d zips in %s' % (glob_expected, glob_pattern)
        return
    with open(os.path.join(delta_dir, 'image_list.txt'), 'rt') as file:
        lines = [line.split() for line in file]
        fields = dict()
        for idx, field in enumerate(lines[0]):
            type = int if field.endswith('idx') else str
            fields[field] = [type(line[idx]) for line in lines[1:]]

    def rot90(v):
        return np.array([-v[1], v[0]])

    def process_func(idx):
        # Load original image.
        orig_idx = fields['orig_idx'][idx]
        orig_file = fields['orig_file'][idx]
        orig_path = os.path.join(celeba_dir, 'img_celeba', orig_file)
        img = PIL.Image.open(orig_path)

        # Choose oriented crop rectangle.
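        # The crop is an oriented square built from the five landmarks: x runs
        # roughly eye-to-eye (its length set by the larger of the eye-to-eye
        # and eye-to-mouth distances), y is x rotated 90 degrees, and the
        # center sits slightly below the eye midpoint, toward the mouth.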
        lm = landmarks[orig_idx]
        eye_avg = (lm[0] + lm[1]) * 0.5 + 0.5
        mouth_avg = (lm[3] + lm[4]) * 0.5 + 0.5
        eye_to_eye = lm[1] - lm[0]
        eye_to_mouth = mouth_avg - eye_avg
        x = eye_to_eye - rot90(eye_to_mouth)
        x /= np.hypot(*x)
        x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8)
        y = rot90(x)
        c = eye_avg + eye_to_mouth * 0.1
        quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
        zoom = 1024 / (np.hypot(*x) * 2)

        # Shrink.
        shrink = int(np.floor(0.5 / zoom))
        if shrink > 1:
            size = (int(np.round(float(img.size[0]) / shrink)), int(np.round(float(img.size[1]) / shrink)))
            img = img.resize(size, PIL.Image.ANTIALIAS)
            quad /= shrink
            zoom *= shrink

        # Crop.
        border = max(int(np.round(1024 * 0.1 / zoom)), 3)
        crop = (int(np.floor(min(quad[:,0]))), int(np.floor(min(quad[:,1]))), int(np.ceil(max(quad[:,0]))), int(np.ceil(max(quad[:,1]))))
        crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, img.size[0]), min(crop[3] + border, img.size[1]))
        if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]:
            img = img.crop(crop)
            quad -= crop[0:2]

        # Simulate super-resolution.
        superres = int(np.exp2(np.ceil(np.log2(zoom))))
        if superres > 1:
            img = img.resize((img.size[0] * superres, img.size[1] * superres), PIL.Image.ANTIALIAS)
            quad *= superres
            zoom /= superres

        # Pad.
        pad = (int(np.floor(min(quad[:,0]))), int(np.floor(min(quad[:,1]))), int(np.ceil(max(quad[:,0]))), int(np.ceil(max(quad[:,1]))))
        pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] - img.size[0] + border, 0), max(pad[3] - img.size[1] + border, 0))
        if max(pad) > border - 4:
            pad = np.maximum(pad, int(np.round(1024 * 0.3 / zoom)))
            img = np.pad(np.float32(img), ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')
            h, w, _ = img.shape
            y, x, _ = np.mgrid[:h, :w, :1]
            mask = 1.0 - np.minimum(np.minimum(np.float32(x) / pad[0], np.float32(y) / pad[1]), np.minimum(np.float32(w-1-x) / pad[2], np.float32(h-1-y) / pad[3]))
            blur = 1024 * 0.02 / zoom
            img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
            img += (np.median(img, axis=(0,1)) - img) * np.clip(mask, 0.0, 1.0)
            img = PIL.Image.fromarray(np.uint8(np.clip(np.round(img), 0, 255)), 'RGB')
            quad += pad[0:2]

        # Transform.
        img = img.transform((4096, 4096), PIL.Image.QUAD, (quad + 0.5).flatten(), PIL.Image.BILINEAR)
        img = img.resize((1024, 1024), PIL.Image.ANTIALIAS)
        img = np.asarray(img).transpose(2, 0, 1)

        # Load delta image and original JPG.
        with zipfile.ZipFile(os.path.join(delta_dir, 'deltas%05d.zip' % (idx - idx % 1000)), 'r') as zip:
            delta_bytes = zip.read('delta%05d.dat' % idx)
        with open(orig_path, 'rb') as file:
            orig_bytes = file.read()

        # Decrypt delta image, using original JPG data as decryption key.
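        # The key is derived from the bytes of the original CelebA JPG via
        # PBKDF2-HMAC-SHA256 (100000 iterations, filename as salt), so the
        # deltas are only usable together with CelebA itself. The decrypted
        # payload is bz2-compressed raw uint8 data of shape (3, 1024, 1024).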
        algorithm = cryptography.hazmat.primitives.hashes.SHA256()
        backend = cryptography.hazmat.backends.default_backend()
        kdf = cryptography.hazmat.primitives.kdf.pbkdf2.PBKDF2HMAC(algorithm=algorithm, length=32, salt=orig_file, iterations=100000, backend=backend)
        key = base64.urlsafe_b64encode(kdf.derive(orig_bytes))
        delta = np.frombuffer(bz2.decompress(cryptography.fernet.Fernet(key).decrypt(delta_bytes)), dtype=np.uint8).reshape(3, 1024, 1024)

        # Apply delta image.
        img = img + delta
        img = np.asarray(img).transpose(1, 2, 0)
        img = PIL.Image.fromarray(img, mode='RGB')
        img512 = img.resize((512, 512), PIL.Image.ANTIALIAS)
        img256 = img.resize((256, 256), PIL.Image.ANTIALIAS)
        img128 = img.resize((128, 128), PIL.Image.ANTIALIAS)
        img64 = img.resize((64, 64), PIL.Image.ANTIALIAS)
        return orig_file, img64, img128, img256, img512, img

    # Save all generated images.
    with ThreadPool(num_threads) as pool:
        for orig_fn, aimg64, aimg128, aimg256, aimg512, aimg1024 in pool.process_items_concurrently(fields['idx'], process_func=process_func, max_items_in_flight=num_tasks):
            aimg64.save('./celeba-hq/celeba-64/'+str(orig_fn))
            aimg128.save('./celeba-hq/celeba-128/'+str(orig_fn))
            aimg256.save('./celeba-hq/celeba-256/'+str(orig_fn))
            aimg512.save('./celeba-hq/celeba-512/'+str(orig_fn))
            aimg1024.save('./celeba-hq/celeba-1024/'+str(orig_fn))
            print(orig_fn)

#----------------------------------------------------------------------------
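
# Example invocation for the modified pipeline (run from inside
# progressive_growing_of_gans; <path to A>/<path to B> are the CelebA and
# delta directories from the README):
#
#   python h5tool.py create_celeba_hq 123456.h5 <path to A> <path to B>
#
# JPEGs are written to ./celeba-hq/celeba-{64,128,256,512,1024}/ under their
# original CelebA filenames; the .h5 argument is accepted but unused here.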

def execute_cmdline(argv):
    prog = argv[0]
    parser = argparse.ArgumentParser(
        prog = prog,
        description = 'Tool for creating, extracting, and visualizing HDF5 datasets.',
        epilog = 'Type "%s -h" for more information.' % prog)

    subparsers = parser.add_subparsers(dest='command')
    def add_command(cmd, desc, example=None):
        epilog = 'Example: %s %s' % (prog, example) if example is not None else None
        return subparsers.add_parser(cmd, description=desc, help=desc, epilog=epilog)

    p = add_command( 'inspect', 'Print information about HDF5 dataset.',
        'inspect mnist-32x32.h5')
    p.add_argument( 'h5_filename', help='HDF5 file to inspect')

    p = add_command( 'compare', 'Compare two HDF5 datasets.',
        'compare mydataset.h5 mnist-32x32.h5')
    p.add_argument( 'first_h5', help='First HDF5 file to compare')
    p.add_argument( 'second_h5', help='Second HDF5 file to compare')

    p = add_command( 'display', 'Display images in HDF5 dataset.',
        'display mnist-32x32.h5')
    p.add_argument( 'h5_filename', help='HDF5 file to visualize')
    p.add_argument( '--start', help='Start index (inclusive)', type=int, default=None)
    p.add_argument( '--stop', help='Stop index (exclusive)', type=int, default=None)
    p.add_argument( '--step', help='Step between consecutive indices', type=int, default=None)

    p = add_command( 'extract', 'Extract images from HDF5 dataset.',
        'extract mnist-32x32.h5 mnist-images')
    p.add_argument( 'h5_filename', help='HDF5 file to extract')
    p.add_argument( 'output_dir', help='Directory to extract the images into')
    p.add_argument( '--start', help='Start index (inclusive)', type=int, default=None)
    p.add_argument( '--stop', help='Stop index (exclusive)', type=int, default=None)
    p.add_argument( '--step', help='Step between consecutive indices', type=int, default=None)

    p = add_command( 'create_custom', 'Create HDF5 dataset for custom images.',
        'create_custom mydataset.h5 myimagedir')
    p.add_argument( 'h5_filename', help='HDF5 file to create')
    p.add_argument( 'image_dir', help='Directory to read the images from')

    p = add_command( 'create_mnist', 'Create HDF5 dataset for MNIST.',
        'create_mnist mnist-32x32.h5 ~/mnist --export_labels')
    p.add_argument( 'h5_filename', help='HDF5 file to create')
    p.add_argument( 'mnist_dir', help='Directory to read MNIST data from')
    p.add_argument( '--export_labels', help='Create *-labels.npy alongside the HDF5', action='store_true')

    p = add_command( 'create_mnist_rgb', 'Create HDF5 dataset for MNIST-RGB.',
        'create_mnist_rgb mnist-rgb-32x32.h5 ~/mnist')
    p.add_argument( 'h5_filename', help='HDF5 file to create')
    p.add_argument( 'mnist_dir', help='Directory to read MNIST data from')
    p.add_argument( '--num_images', help='Number of composite images to create (default: 1000000)', type=int, default=1000000)
    p.add_argument( '--random_seed', help='Random seed (default: 123)', type=int, default=123)

    p = add_command( 'create_cifar10', 'Create HDF5 dataset for CIFAR-10.',
        'create_cifar10 cifar-10-32x32.h5 ~/cifar10 --export_labels')
    p.add_argument( 'h5_filename', help='HDF5 file to create')
    p.add_argument( 'cifar10_dir', help='Directory to read CIFAR-10 data from')
    p.add_argument( '--export_labels', help='Create *-labels.npy alongside the HDF5', action='store_true')

    p = add_command( 'create_lsun', 'Create HDF5 dataset for single LSUN category.',
        'create_lsun lsun-airplane-256x256-100k.h5 ~/lsun/airplane_lmdb --resolution 256 --max_images 100000')
    p.add_argument( 'h5_filename', help='HDF5 file to create')
    p.add_argument( 'lmdb_dir', help='Directory to read LMDB database from')
    p.add_argument( '--resolution', help='Output resolution (default: 256)', type=int, default=256)
    p.add_argument( '--max_images', help='Maximum number of images (default: none)', type=int, default=None)

    p = add_command( 'create_celeba', 'Create HDF5 dataset for CelebA.',
        'create_celeba celeba-128x128.h5 ~/celeba')
    p.add_argument( 'h5_filename', help='HDF5 file to create')
    p.add_argument( 'celeba_dir', help='Directory to read CelebA data from')
    p.add_argument( '--cx', help='Center X coordinate (default: 89)', type=int, default=89)
    p.add_argument( '--cy', help='Center Y coordinate (default: 121)', type=int, default=121)

    p = add_command( 'create_celeba_hq', 'Create CelebA-HQ images (this fork saves JPEGs under ./celeba-hq).',
        'create_celeba_hq celeba-hq-1024x1024.h5 ~/celeba ~/celeba-hq-deltas')
    p.add_argument( 'h5_filename', help='HDF5 filename (kept for compatibility; unused by this fork)')
    p.add_argument( 'celeba_dir', help='Directory to read CelebA data from')
    p.add_argument( 'delta_dir', help='Directory to read CelebA-HQ deltas from')
    p.add_argument( '--num_threads', help='Number of concurrent threads (default: 4)', type=int, default=4)
    p.add_argument( '--num_tasks', help='Number of concurrent processing tasks (default: 100)', type=int, default=100)

    args = parser.parse_args(argv[1:])
    func = globals()[args.command]
    del args.command
    func(**vars(args))

#----------------------------------------------------------------------------

if __name__ == "__main__":
    execute_cmdline(sys.argv)

#----------------------------------------------------------------------------
--------------------------------------------------------------------------------