├── assets
│   ├── G1.png
│   ├── G2.png
│   ├── G3.png
│   ├── G4.png
│   ├── G5.png
│   ├── G6.png
│   ├── G7.png
│   ├── 58600.png
│   ├── AE_G1.png
│   ├── AE_G2.png
│   ├── AE_G3.png
│   ├── AE_G4.png
│   ├── AE_G5.png
│   ├── AE_G6.png
│   ├── AE_G7.png
│   ├── AE_G8.png
│   ├── AE_G9.png
│   ├── model.png
│   ├── 82400_1.png
│   ├── 82400_2.png
│   ├── AE_G10.png
│   ├── AE_G11.png
│   ├── AE_G12.png
│   ├── AE_G13.png
│   ├── AE_G14.png
│   ├── AE_G15.png
│   ├── AE_G16.png
│   ├── AE_G17.png
│   ├── AE_G18.png
│   ├── AE_G19.png
│   ├── AE_G20.png
│   ├── AE_G21.png
│   ├── AE_G22.png
│   ├── AE_G23.png
│   ├── AE_G24.png
│   ├── AE_G25.png
│   ├── AE_G26.png
│   ├── 104050_G.png
│   ├── 107300_G.png
│   ├── 115827_G.png
│   ├── AE_batch.png
│   ├── interp_1.png
│   ├── interp_10.png
│   ├── interp_2.png
│   ├── interp_3.png
│   ├── interp_4.png
│   ├── interp_5.png
│   ├── interp_6.png
│   ├── interp_7.png
│   ├── interp_8.png
│   ├── interp_9.png
│   ├── 104050_AE_G.png
│   ├── 107300_AE_G.png
│   ├── 115827_AE_G.png
│   ├── all_G_z0_64x64.png
│   ├── interp_G0_64x64.png
│   ├── all_G_z0_128x128.png
│   └── interp_G0_128x128.png
├── main.py
├── folder.py
├── data_loader.py
├── .gitignore
├── README.md
├── utils.py
├── config.py
├── models.py
├── download.py
├── layers.py
└── trainer.py
--------------------------------------------------------------------------------
/assets/*.png: [binary images: generator sample grids (G*.png, all_G_z0_*.png, *_G.png), autoencoder outputs (AE_G*.png, AE_batch.png, *_AE_G.png), latent interpolations (interp_*.png), and the model diagram (model.png); in the original dump each entry pointed to https://raw.githubusercontent.com/tikyau/BEGAN-tensorflow/master/assets/<filename>]
--------------------------------------------------------------------------------
/main.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | from trainer import Trainer 5 | from config import get_config 6 | from data_loader import get_loader 7 | from utils import prepare_dirs_and_logger, save_config 8 | 9 | def main(config): 10 | prepare_dirs_and_logger(config) 11 | 12 | rng = np.random.RandomState(config.random_seed) 13 | tf.set_random_seed(config.random_seed) 14 | 15 | if config.is_train: 16 | data_path = config.data_path 17 | batch_size = config.batch_size 18 | do_shuffle = True 19 | else: 20 | setattr(config, 'batch_size', 64) 21 | if config.test_data_path is None: 22 | data_path = config.data_path 23 | else: 24 | data_path = config.test_data_path 25 | batch_size = config.sample_per_image 26 | do_shuffle = False 27 | 28 | data_loader = get_loader( 29 | data_path, config.batch_size, config.input_scale_size, 30 | config.data_format, config.split) 31 | trainer = Trainer(config, data_loader) 32 | 33 | if config.is_train: 34 | save_config(config) 35 |
trainer.train() 36 | else: 37 | if not config.load_path: 38 | raise Exception("[!] You should specify `load_path` to load a pretrained model") 39 | trainer.test() 40 | 41 | if __name__ == "__main__": 42 | config, unparsed = get_config() 43 | main(config) 44 | -------------------------------------------------------------------------------- /folder.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | 3 | from PIL import Image 4 | import os 5 | import os.path 6 | 7 | IMG_EXTENSIONS = [ 8 | '.jpg', '.JPG', '.jpeg', '.JPEG', 9 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 10 | ] 11 | 12 | def is_image_file(filename): 13 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 14 | 15 | def make_dataset(dir): 16 | images = [] 17 | for root, _, fnames in sorted(os.walk(dir)): 18 | for fname in sorted(fnames): 19 | if is_image_file(fname): 20 | path = os.path.join(root, fname) 21 | item = (path, 0) 22 | images.append(item) 23 | 24 | return images 25 | 26 | def default_loader(path): 27 | return Image.open(path).convert('RGB') 28 | 29 | class ImageFolder(data.Dataset): 30 | 31 | def __init__(self, root, transform=None, target_transform=None, 32 | loader=default_loader): 33 | imgs = make_dataset(root) 34 | if len(imgs) == 0: 35 | raise(RuntimeError("Found 0 images in subfolders of: " + root + "\n" 36 | "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))) 37 | 38 | print("Found {} images in subfolders of: {}".format(len(imgs), root)) 39 | 40 | self.root = root 41 | self.imgs = imgs 42 | self.transform = transform 43 | self.target_transform = target_transform 44 | self.loader = loader 45 | 46 | def __getitem__(self, index): 47 | path, target = self.imgs[index] 48 | img = self.loader(path) 49 | if self.transform is not None: 50 | img = self.transform(img) 51 | if self.target_transform is not None: 52 | target = self.target_transform(target) 53 | 54 | return img, target 55 | 56 | def __len__(self): 57 | return len(self.imgs) 58 | -------------------------------------------------------------------------------- /data_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | from glob import glob 4 | import tensorflow as tf 5 | 6 | def get_loader(root, batch_size, scale_size, data_format, split=None, is_grayscale=False, seed=None): 7 | dataset_name = os.path.basename(root) 8 | if dataset_name in ['CelebA'] and split: 9 | root = os.path.join(root, 'splits', split) 10 | 11 | for ext in ["jpg", "png"]: 12 | paths = glob("{}/*.{}".format(root, ext)) 13 | 14 | if ext == "jpg": 15 | tf_decode = tf.image.decode_jpeg 16 | elif ext == "png": 17 | tf_decode = tf.image.decode_png 18 | 19 | if len(paths) != 0: 20 | break 21 | 22 | with Image.open(paths[0]) as img: 23 | w, h = img.size 24 | shape = [h, w, 3] 25 | 26 | filename_queue = tf.train.string_input_producer(list(paths), shuffle=False, seed=seed) 27 | reader = tf.WholeFileReader() 28 | filename, data = reader.read(filename_queue) 29 | image = tf_decode(data, channels=3) 30 | 31 | if is_grayscale: 32 | image = tf.image.rgb_to_grayscale(image) 33 | image.set_shape(shape) 34 | 35 | min_after_dequeue = 5000 36 | capacity = min_after_dequeue + 3 * batch_size 37 | 38 | queue = tf.train.shuffle_batch( 39 | [image], batch_size=batch_size, 40 | num_threads=4, capacity=capacity, 41 | min_after_dequeue=min_after_dequeue, name='synthetic_inputs') 42 | 43 | if dataset_name in ['CelebA']: 44 | 
queue = tf.image.crop_to_bounding_box(queue, 50, 25, 128, 128) 45 | queue = tf.image.resize_nearest_neighbor(queue, [scale_size, scale_size]) 46 | else: 47 | queue = tf.image.resize_nearest_neighbor(queue, [scale_size, scale_size]) 48 | 49 | if data_format == 'NCHW': 50 | queue = tf.transpose(queue, [0, 3, 1, 2]) 51 | elif data_format == 'NHWC': 52 | pass 53 | else: 54 | raise Exception("[!] Unknown data_format: {}".format(data_format)) 55 | 56 | return tf.to_float(queue) 57 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Data 2 | test* 3 | data/hand 4 | data/gaze 5 | data/* 6 | samples 7 | outputs 8 | 9 | # ipython checkpoints 10 | .ipynb_checkpoints 11 | 12 | # Log 13 | logs 14 | 15 | # ETC 16 | paper.pdf 17 | .DS_Store 18 | 19 | # Created by https://www.gitignore.io/api/python,vim 20 | 21 | ### Python ### 22 | # Byte-compiled / optimized / DLL files 23 | __pycache__/ 24 | *.py[cod] 25 | *$py.class 26 | 27 | # C extensions 28 | *.so 29 | 30 | # Distribution / packaging 31 | .Python 32 | env/ 33 | build/ 34 | develop-eggs/ 35 | dist/ 36 | downloads/ 37 | eggs/ 38 | .eggs/ 39 | lib/ 40 | lib64/ 41 | parts/ 42 | sdist/ 43 | var/ 44 | wheels/ 45 | *.egg-info/ 46 | .installed.cfg 47 | *.egg 48 | 49 | # PyInstaller 50 | # Usually these files are written by a python script from a template 51 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 52 | *.manifest 53 | *.spec 54 | 55 | # Installer logs 56 | pip-log.txt 57 | pip-delete-this-directory.txt 58 | 59 | # Unit test / coverage reports 60 | htmlcov/ 61 | .tox/ 62 | .coverage 63 | .coverage.* 64 | .cache 65 | nosetests.xml 66 | coverage.xml 67 | *,cover 68 | .hypothesis/ 69 | 70 | # Translations 71 | *.mo 72 | *.pot 73 | 74 | # Django stuff: 75 | *.log 76 | local_settings.py 77 | 78 | # Flask stuff: 79 | instance/ 80 | .webassets-cache 81 | 82 | # Scrapy stuff: 83 | .scrapy 84 | 85 | # Sphinx documentation 86 | docs/_build/ 87 | 88 | # PyBuilder 89 | target/ 90 | 91 | # Jupyter Notebook 92 | .ipynb_checkpoints 93 | 94 | # pyenv 95 | .python-version 96 | 97 | # celery beat schedule file 98 | celerybeat-schedule 99 | 100 | # dotenv 101 | .env 102 | 103 | # virtualenv 104 | .venv/ 105 | venv/ 106 | ENV/ 107 | 108 | # Spyder project settings 109 | .spyderproject 110 | 111 | # Rope project settings 112 | .ropeproject 113 | 114 | 115 | ### Vim ### 116 | # swap 117 | [._]*.s[a-v][a-z] 118 | [._]*.sw[a-p] 119 | [._]s[a-v][a-z] 120 | [._]sw[a-p] 121 | # session 122 | Session.vim 123 | # temporary 124 | .netrwhist 125 | *~ 126 | # auto-generated tag files 127 | tags 128 | 129 | # End of https://www.gitignore.io/api/python,vim 130 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # BEGAN in Tensorflow 2 | 3 | TensorFlow implementation of [BEGAN: Boundary Equilibrium Generative Adversarial Networks](https://arxiv.org/abs/1703.10717).
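The objective implemented in `build_model` (trainer.py) trains the discriminator as an autoencoder and balances the two networks with a proportional control variable `k_t`. For reference, here is a minimal NumPy sketch of the per-step bookkeeping (the standalone function below is illustrative only, not part of this repo):

    import numpy as np

    def began_step(d_loss_real, d_loss_fake, k_t, gamma=0.5, lambda_k=0.001):
        # d_loss_real / d_loss_fake are L1 autoencoder losses on real and generated batches
        d_loss = d_loss_real - k_t * d_loss_fake   # discriminator objective
        g_loss = d_loss_fake                       # generator objective (same L1 term)
        balance = gamma * d_loss_real - g_loss     # deviation from equilibrium
        k_t = float(np.clip(k_t + lambda_k * balance, 0, 1))
        measure = d_loss_real + abs(balance)       # global convergence measure
        return d_loss, g_loss, k_t, measure

`gamma` trades sample diversity against quality, and `measure` is the scalar logged as `misc/measure` in TensorBoard.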
4 | 5 | ![alt tag](./assets/model.png) 6 | 7 | 8 | ## Requirements 9 | 10 | - Python 2.7 11 | - [Pillow](https://pillow.readthedocs.io/en/4.0.x/) 12 | - [tqdm](https://github.com/tqdm/tqdm) 13 | - [requests](https://github.com/kennethreitz/requests) (only used to download the CelebA dataset) 14 | - [TensorFlow 1.1.0](https://github.com/tensorflow/tensorflow) (**requires a nightly build**, which can be found [here](https://github.com/tensorflow/tensorflow#installation); otherwise you'll see `ValueError: 'image' must be three-dimensional.`) 15 | 16 | 17 | ## Usage 18 | 19 | First, download the [CelebA](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html) dataset with: 20 | 21 | $ apt-get install p7zip-full # ubuntu 22 | $ brew install p7zip # Mac 23 | $ python download.py 24 | 25 | or use your own dataset by arranging images like this: 26 | 27 | data 28 | └── YOUR_DATASET_NAME 29 | ├── xxx.jpg (name doesn't matter) 30 | ├── yyy.jpg 31 | └── ... 32 | 33 | To train a model: 34 | 35 | $ python main.py --dataset=CelebA --use_gpu=True 36 | $ python main.py --dataset=YOUR_DATASET_NAME --use_gpu=True 37 | 38 | To test a model (use your `load_path`): 39 | 40 | $ python main.py --dataset=CelebA --load_path=CelebA_0405_124806 --use_gpu=True --is_train=False --split valid 41 | 42 | 43 | ## Results 44 | 45 | ### Generator output (64x64) with `gamma=0.5` after 300k steps 46 | 47 | ![all_G_z0_64x64](./assets/all_G_z0_64x64.png) 48 | 49 | 50 | ### Generator output (128x128) with `gamma=0.5` after 200k steps 51 | 52 | ![all_G_z0_128x128](./assets/all_G_z0_128x128.png) 53 | 54 | 55 | ### Interpolation of Generator output (64x64) with `gamma=0.5` after 300k steps 56 | 57 | ![interp_G0_64x64](./assets/interp_G0_64x64.png) 58 | 59 | 60 | ### Interpolation of Generator output (128x128) with `gamma=0.5` after 200k steps 61 | 62 | ![interp_G0_128x128](./assets/interp_G0_128x128.png) 63 | 64 | 65 | ### Interpolation of Discriminator output of real images 66 | 67 | ![alt tag](./assets/AE_batch.png) 68 | ![alt tag](./assets/interp_1.png) 69 | ![alt tag](./assets/interp_2.png) 70 | ![alt tag](./assets/interp_3.png) 71 | ![alt tag](./assets/interp_4.png) 72 | ![alt tag](./assets/interp_5.png) 73 | ![alt tag](./assets/interp_6.png) 74 | ![alt tag](./assets/interp_7.png) 75 | ![alt tag](./assets/interp_8.png) 76 | ![alt tag](./assets/interp_9.png) 77 | ![alt tag](./assets/interp_10.png) 78 | 79 | 80 | ## Related work 81 | 82 | - [DCGAN-tensorflow](https://github.com/carpedm20/DCGAN-tensorflow) 83 | - [DiscoGAN-pytorch](https://github.com/carpedm20/DiscoGAN-pytorch) 84 | - [simulated-unsupervised-tensorflow](https://github.com/carpedm20/simulated-unsupervised-tensorflow) 85 | 86 | 87 | ## Author 88 | 89 | Taehoon Kim / [@carpedm20](http://carpedm20.github.io) 90 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import os 4 | import math 5 | import json 6 | import logging 7 | import numpy as np 8 | from PIL import Image 9 | from datetime import datetime 10 | 11 | def prepare_dirs_and_logger(config): 12 | formatter = logging.Formatter("%(asctime)s:%(levelname)s::%(message)s") 13 | logger = logging.getLogger() 14 | 15 | for hdlr in logger.handlers: 16 | logger.removeHandler(hdlr) 17 | 18 | handler = logging.StreamHandler() 19 | handler.setFormatter(formatter) 20 | 21 | logger.addHandler(handler) 22 | 23 | if config.load_path: 24 | if
config.load_path.startswith(config.log_dir): 25 | config.model_dir = config.load_path 26 | else: 27 | if config.load_path.startswith(config.dataset): 28 | config.model_name = config.load_path 29 | else: 30 | config.model_name = "{}_{}".format(config.dataset, config.load_path) 31 | else: 32 | config.model_name = "{}_{}".format(config.dataset, get_time()) 33 | 34 | if not hasattr(config, 'model_dir'): 35 | config.model_dir = os.path.join(config.log_dir, config.model_name) 36 | config.data_path = os.path.join(config.data_dir, config.dataset) 37 | 38 | for path in [config.log_dir, config.data_dir, config.model_dir]: 39 | if not os.path.exists(path): 40 | os.makedirs(path) 41 | 42 | def get_time(): 43 | return datetime.now().strftime("%m%d_%H%M%S") 44 | 45 | def save_config(config): 46 | param_path = os.path.join(config.model_dir, "params.json") 47 | 48 | print("[*] MODEL dir: %s" % config.model_dir) 49 | print("[*] PARAM path: %s" % param_path) 50 | 51 | with open(param_path, 'w') as fp: 52 | json.dump(config.__dict__, fp, indent=4, sort_keys=True) 53 | 54 | def rank(array): 55 | return len(array.shape) 56 | 57 | def make_grid(tensor, nrow=8, padding=2, 58 | normalize=False, scale_each=False): 59 | """Code based on https://github.com/pytorch/vision/blob/master/torchvision/utils.py""" 60 | nmaps = tensor.shape[0] 61 | xmaps = min(nrow, nmaps) 62 | ymaps = int(math.ceil(float(nmaps) / xmaps)) 63 | height, width = int(tensor.shape[1] + padding), int(tensor.shape[2] + padding) 64 | grid = np.zeros([height * ymaps + 1 + padding // 2, width * xmaps + 1 + padding // 2, 3], dtype=np.uint8) 65 | k = 0 66 | for y in range(ymaps): 67 | for x in range(xmaps): 68 | if k >= nmaps: 69 | break 70 | h, h_width = y * height + 1 + padding // 2, height - padding 71 | w, w_width = x * width + 1 + padding // 2, width - padding 72 | 73 | grid[h:h+h_width, w:w+w_width] = tensor[k] 74 | k = k + 1 75 | return grid 76 | 77 | def save_image(tensor, filename, nrow=8, padding=2, 78 | normalize=False, scale_each=False): 79 | ndarr = make_grid(tensor, nrow=nrow, padding=padding, 80 | normalize=normalize, scale_each=scale_each) 81 | im = Image.fromarray(ndarr) 82 | im.save(filename) 83 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | #-*- coding: utf-8 -*- 2 | import argparse 3 | 4 | def str2bool(v): 5 | return v.lower() in ('true', '1') 6 | 7 | arg_lists = [] 8 | parser = argparse.ArgumentParser() 9 | 10 | def add_argument_group(name): 11 | arg = parser.add_argument_group(name) 12 | arg_lists.append(arg) 13 | return arg 14 | 15 | # Network 16 | net_arg = add_argument_group('Network') 17 | net_arg.add_argument('--input_scale_size', type=int, default=64, 18 | help='input image will be resized with the given value as width and height') 19 | net_arg.add_argument('--conv_hidden_num', type=int, default=128, 20 | choices=[64, 128],help='n in the paper') 21 | net_arg.add_argument('--z_num', type=int, default=64, choices=[64, 128]) 22 | 23 | # Data 24 | data_arg = add_argument_group('Data') 25 | data_arg.add_argument('--dataset', type=str, default='CelebA') 26 | data_arg.add_argument('--split', type=str, default='train') 27 | data_arg.add_argument('--batch_size', type=int, default=16) 28 | data_arg.add_argument('--grayscale', type=str2bool, default=False) 29 | data_arg.add_argument('--num_worker', type=int, default=4) 30 | 31 | # Training / test parameters 32 | train_arg = add_argument_group('Training') 
33 | train_arg.add_argument('--is_train', type=str2bool, default=True) 34 | train_arg.add_argument('--optimizer', type=str, default='adam') 35 | train_arg.add_argument('--max_step', type=int, default=500000) 36 | train_arg.add_argument('--lr_update_step', type=int, default=100000, choices=[100000, 75000]) 37 | train_arg.add_argument('--d_lr', type=float, default=0.00008) 38 | train_arg.add_argument('--g_lr', type=float, default=0.00008) 39 | train_arg.add_argument('--lr_lower_boundary', type=float, default=0.00002) 40 | train_arg.add_argument('--beta1', type=float, default=0.5) 41 | train_arg.add_argument('--beta2', type=float, default=0.999) 42 | train_arg.add_argument('--gamma', type=float, default=0.5) 43 | train_arg.add_argument('--lambda_k', type=float, default=0.001) 44 | train_arg.add_argument('--use_gpu', type=str2bool, default=True) 45 | 46 | # Misc 47 | misc_arg = add_argument_group('Misc') 48 | misc_arg.add_argument('--load_path', type=str, default='') 49 | misc_arg.add_argument('--log_step', type=int, default=50) 50 | misc_arg.add_argument('--save_step', type=int, default=5000) 51 | misc_arg.add_argument('--num_log_samples', type=int, default=3) 52 | misc_arg.add_argument('--log_level', type=str, default='INFO', choices=['INFO', 'DEBUG', 'WARN']) 53 | misc_arg.add_argument('--log_dir', type=str, default='logs') 54 | misc_arg.add_argument('--data_dir', type=str, default='data') 55 | misc_arg.add_argument('--test_data_path', type=str, default=None, 56 | help='directory with images which will be used in test sample generation') 57 | misc_arg.add_argument('--sample_per_image', type=int, default=64, 58 | help='# of sample per image during test sample generation') 59 | misc_arg.add_argument('--random_seed', type=int, default=123) 60 | 61 | def get_config(): 62 | config, unparsed = parser.parse_known_args() 63 | if config.use_gpu: 64 | data_format = 'NCHW' 65 | else: 66 | data_format = 'NHWC' 67 | setattr(config, 'data_format', data_format) 68 | return config, unparsed 69 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | slim = tf.contrib.slim 4 | 5 | def GeneratorCNN(z, hidden_num, output_num, repeat_num, data_format, reuse): 6 | with tf.variable_scope("G", reuse=reuse) as vs: 7 | num_output = int(np.prod([8, 8, hidden_num])) 8 | x = slim.fully_connected(z, num_output, activation_fn=None) 9 | x = reshape(x, 8, 8, hidden_num, data_format) 10 | 11 | for idx in range(repeat_num): 12 | x = slim.conv2d(x, hidden_num, 3, 1, activation_fn=tf.nn.elu, data_format=data_format) 13 | x = slim.conv2d(x, hidden_num, 3, 1, activation_fn=tf.nn.elu, data_format=data_format) 14 | if idx < repeat_num - 1: 15 | x = upscale(x, 2, data_format) 16 | 17 | out = slim.conv2d(x, 3, 3, 1, activation_fn=None, data_format=data_format) 18 | 19 | variables = tf.contrib.framework.get_variables(vs) 20 | return out, variables 21 | 22 | def DiscriminatorCNN(x, input_channel, z_num, repeat_num, hidden_num, data_format): 23 | with tf.variable_scope("D") as vs: 24 | # Encoder 25 | x = slim.conv2d(x, hidden_num, 3, 1, activation_fn=tf.nn.elu, data_format=data_format) 26 | 27 | prev_channel_num = hidden_num 28 | for idx in range(repeat_num): 29 | channel_num = hidden_num * (idx + 1) 30 | x = slim.conv2d(x, channel_num, 3, 1, activation_fn=tf.nn.elu, data_format=data_format) 31 | x = slim.conv2d(x, channel_num, 3, 1, activation_fn=tf.nn.elu, 
data_format=data_format) 32 | if idx < repeat_num - 1: 33 | x = slim.conv2d(x, channel_num, 3, 2, activation_fn=tf.nn.elu, data_format=data_format) 34 | #x = tf.contrib.layers.max_pool2d(x, [2, 2], [2, 2], padding='VALID') 35 | 36 | x = tf.reshape(x, [-1, np.prod([8, 8, channel_num])]) 37 | z = x = slim.fully_connected(x, z_num, activation_fn=None) 38 | 39 | # Decoder 40 | num_output = int(np.prod([8, 8, hidden_num])) 41 | x = slim.fully_connected(x, num_output, activation_fn=None) 42 | x = reshape(x, 8, 8, hidden_num, data_format) 43 | 44 | for idx in range(repeat_num): 45 | x = slim.conv2d(x, hidden_num, 3, 1, activation_fn=tf.nn.elu, data_format=data_format) 46 | x = slim.conv2d(x, hidden_num, 3, 1, activation_fn=tf.nn.elu, data_format=data_format) 47 | if idx < repeat_num - 1: 48 | x = upscale(x, 2, data_format) 49 | 50 | out = slim.conv2d(x, input_channel, 3, 1, activation_fn=None, data_format=data_format) 51 | 52 | variables = tf.contrib.framework.get_variables(vs) 53 | return out, z, variables 54 | 55 | def int_shape(tensor): 56 | shape = tensor.get_shape().as_list() 57 | return [num if num is not None else -1 for num in shape] 58 | 59 | def get_conv_shape(tensor, data_format): 60 | shape = int_shape(tensor) 61 | # always return [N, H, W, C] 62 | if data_format == 'NCHW': 63 | return [shape[0], shape[2], shape[3], shape[1]] 64 | elif data_format == 'NHWC': 65 | return shape 66 | 67 | def nchw_to_nhwc(x): 68 | return tf.transpose(x, [0, 2, 3, 1]) 69 | 70 | def nhwc_to_nchw(x): 71 | return tf.transpose(x, [0, 3, 1, 2]) 72 | 73 | def reshape(x, h, w, c, data_format): 74 | if data_format == 'NCHW': 75 | x = tf.reshape(x, [-1, c, h, w]) 76 | else: 77 | x = tf.reshape(x, [-1, h, w, c]) 78 | return x 79 | 80 | def resize_nearest_neighbor(x, new_size, data_format): 81 | if data_format == 'NCHW': 82 | x = nchw_to_nhwc(x) 83 | x = tf.image.resize_nearest_neighbor(x, new_size) 84 | x = nhwc_to_nchw(x) 85 | else: 86 | x = tf.image.resize_nearest_neighbor(x, new_size) 87 | return x 88 | 89 | def upscale(x, scale, data_format): 90 | _, h, w, _ = get_conv_shape(x, data_format) 91 | return resize_nearest_neighbor(x, (h*scale, w*scale), data_format) 92 | -------------------------------------------------------------------------------- /download.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modification of 3 | - https://github.com/carpedm20/DCGAN-tensorflow/blob/master/download.py 4 | - http://stackoverflow.com/a/39225039 5 | """ 6 | from __future__ import print_function 7 | import os 8 | import zipfile 9 | import requests 10 | import subprocess 11 | from tqdm import tqdm 12 | from collections import OrderedDict 13 | 14 | def download_file_from_google_drive(id, destination): 15 | URL = "https://docs.google.com/uc?export=download" 16 | session = requests.Session() 17 | 18 | response = session.get(URL, params={ 'id': id }, stream=True) 19 | token = get_confirm_token(response) 20 | 21 | if token: 22 | params = { 'id' : id, 'confirm' : token } 23 | response = session.get(URL, params=params, stream=True) 24 | 25 | save_response_content(response, destination) 26 | 27 | def get_confirm_token(response): 28 | for key, value in response.cookies.items(): 29 | if key.startswith('download_warning'): 30 | return value 31 | return None 32 | 33 | def save_response_content(response, destination, chunk_size=32*1024): 34 | total_size = int(response.headers.get('content-length', 0)) 35 | with open(destination, "wb") as f: 36 | for chunk in 
tqdm(response.iter_content(chunk_size), total=(total_size + chunk_size - 1) // chunk_size, 37 | unit='chunk', desc=destination): # total counts chunks, not bytes, so the bar can reach 100% 38 | if chunk: # filter out keep-alive new chunks 39 | f.write(chunk) 40 | 41 | def unzip(filepath): 42 | print("Extracting: " + filepath) 43 | base_path = os.path.dirname(filepath) 44 | with zipfile.ZipFile(filepath) as zf: 45 | zf.extractall(base_path) 46 | os.remove(filepath) 47 | 48 | def download_celeb_a(base_path): 49 | data_path = os.path.join(base_path, 'CelebA') 50 | images_path = os.path.join(data_path, 'images') 51 | if os.path.exists(data_path): 52 | print('[!] Found Celeb-A - skip') 53 | return 54 | 55 | filename, drive_id = "img_align_celeba.zip", "0B7EVK8r0v71pZjFTYXZWM3FlRnM" 56 | save_path = os.path.join(base_path, filename) 57 | 58 | if os.path.exists(save_path): 59 | print('[*] {} already exists'.format(save_path)) 60 | else: 61 | download_file_from_google_drive(drive_id, save_path) 62 | 63 | zip_dir = '' 64 | with zipfile.ZipFile(save_path) as zf: 65 | zip_dir = zf.namelist()[0] 66 | zf.extractall(base_path) 67 | if not os.path.exists(data_path): 68 | os.mkdir(data_path) 69 | os.rename(os.path.join(base_path, "img_align_celeba"), images_path) 70 | os.remove(save_path) 71 | 72 | def prepare_data_dir(path='./data'): 73 | if not os.path.exists(path): 74 | os.mkdir(path) 75 | 76 | # if the file exists, make a relative symlink to it 77 | def check_link(in_dir, basename, out_dir): 78 | in_file = os.path.join(in_dir, basename) 79 | if os.path.exists(in_file): 80 | link_file = os.path.join(out_dir, basename) 81 | rel_link = os.path.relpath(in_file, out_dir) 82 | os.symlink(rel_link, link_file) 83 | 84 | def add_splits(base_path): 85 | data_path = os.path.join(base_path, 'CelebA') 86 | images_path = os.path.join(data_path, 'images') 87 | train_dir = os.path.join(data_path, 'splits', 'train') 88 | valid_dir = os.path.join(data_path, 'splits', 'valid') 89 | test_dir = os.path.join(data_path, 'splits', 'test') 90 | if not os.path.exists(train_dir): 91 | os.makedirs(train_dir) 92 | if not os.path.exists(valid_dir): 93 | os.makedirs(valid_dir) 94 | if not os.path.exists(test_dir): 95 | os.makedirs(test_dir) 96 | 97 | # these constants are based on the standard CelebA splits 98 | NUM_EXAMPLES = 202599 99 | TRAIN_STOP = 162770 100 | VALID_STOP = 182637 101 | 102 | for i in range(0, TRAIN_STOP): 103 | basename = "{:06d}.jpg".format(i+1) 104 | check_link(images_path, basename, train_dir) 105 | for i in range(TRAIN_STOP, VALID_STOP): 106 | basename = "{:06d}.jpg".format(i+1) 107 | check_link(images_path, basename, valid_dir) 108 | for i in range(VALID_STOP, NUM_EXAMPLES): 109 | basename = "{:06d}.jpg".format(i+1) 110 | check_link(images_path, basename, test_dir) 111 | 112 | if __name__ == '__main__': 113 | base_path = './data' 114 | prepare_data_dir() 115 | download_celeb_a(base_path) 116 | add_splits(base_path) 117 | -------------------------------------------------------------------------------- /layers.py: -------------------------------------------------------------------------------- 1 | # Code from https://github.com/david-berthelot/tf_img_tech/blob/master/tfswag/layers.py 2 | import numpy as N 3 | import numpy.linalg as LA 4 | import tensorflow as tf 5 | 6 | __author__ = 'David Berthelot' 7 | 8 | 9 | def unboxn(vin, n): 10 | """vin = (batch, h, w, depth), returns vout = (batch, n*h, n*w, depth), each pixel is duplicated.""" 11 | s = tf.shape(vin) 12 | vout = tf.concat([vin] * (n ** 2), 0) # Poor man's replacement for tf.tile (required for Adversarial Training support).
13 | vout = tf.reshape(vout, [s[0] * (n ** 2), s[1], s[2], s[3]]) 14 | vout = tf.batch_to_space(vout, [[0, 0], [0, 0]], n) 15 | return vout 16 | 17 | 18 | def boxn(vin, n): 19 | """vin = (batch, h, w, depth), returns vout = (batch, h//n, w//n, depth), each pixel is averaged.""" 20 | if n == 1: 21 | return vin 22 | s = tf.shape(vin) 23 | vout = tf.reshape(vin, [s[0], s[1] // n, n, s[2] // n, n, s[3]]) 24 | vout = tf.reduce_mean(vout, [2, 4]) 25 | return vout 26 | 27 | 28 | class LayerBase: 29 | pass 30 | 31 | 32 | class LayerConv(LayerBase): 33 | def __init__(self, name, w, n, nl=lambda x, y: x + y, strides=(1, 1, 1, 1), 34 | padding='SAME', conv=None, use_bias=True, data_format="NCHW"): 35 | """w = (wy, wx), n = (n_in, n_out)""" 36 | self.nl = nl 37 | self.strides = list(strides) 38 | self.padding = padding 39 | self.data_format = data_format 40 | with tf.name_scope(name): 41 | if conv is None: 42 | conv = tf.Variable(tf.truncated_normal([w[0], w[1], n[0], n[1]], stddev=0.01), name='conv') 43 | self.conv = conv 44 | self.bias = tf.Variable(tf.zeros([n[1]]), name='bias') if use_bias else 0 45 | 46 | def __call__(self, vin): 47 | return self.nl(tf.nn.conv2d(vin, self.conv, strides=self.strides, 48 | padding=self.padding, data_format=self.data_format), self.bias) 49 | 50 | class LayerEncodeConvGrowLinear(LayerBase): 51 | def __init__(self, name, n, width, colors, depth, scales, nl=lambda x, y: x + y, data_format="NCHW"): 52 | with tf.variable_scope(name) as vs: 53 | encode = [] 54 | nn = n 55 | for x in range(scales): 56 | cl = [] 57 | for y in range(depth - 1): 58 | cl.append(LayerConv('conv_%d_%d' % (x, y), [width, width], 59 | [nn, nn], nl, data_format=data_format)) 60 | cl.append(LayerConv('conv_%d_%d' % (x, depth - 1), [width, width], 61 | [nn, nn + n], nl, strides=[1, 2, 2, 1], data_format=data_format)) 62 | encode.append(cl) 63 | nn += n 64 | self.encode = [LayerConv('conv_pre', [width, width], [colors, n], nl, data_format=data_format), encode] 65 | self.variables = tf.contrib.framework.get_variables(vs) 66 | 67 | def __call__(self, vin, carry=0, train=True): 68 | vout = self.encode[0](vin) 69 | for convs in self.encode[1]: 70 | for conv in convs[:-1]: 71 | vtmp = tf.nn.elu(conv(vout)) 72 | vout = carry * vout + (1 - carry) * vtmp 73 | vout = convs[-1](vout) 74 | return vout, self.variables 75 | 76 | 77 | class LayerDecodeConvBlend(LayerBase): 78 | def __init__(self, name, n, width, colors, depth, scales, nl=lambda x, y: x + y, data_format="NCHW"): 79 | with tf.variable_scope(name) as vs: 80 | decode = [] 81 | for x in range(scales): 82 | cl = [] 83 | n2 = 2 * n if x else n 84 | cl.append(LayerConv('conv_%d_%d' % (x, 0), [width, width], 85 | [n2, n], nl, data_format=data_format)) 86 | for y in range(1, depth): 87 | cl.append(LayerConv('conv_%d_%d' % (x, y), [width, width], [n, n], nl, data_format=data_format)) 88 | decode.append(cl) 89 | self.decode = [decode, LayerConv('conv_post', [width, width], [n, colors], data_format=data_format)] 90 | self.variables = tf.contrib.framework.get_variables(vs) 91 | 92 | def __call__(self, data, carry, train=True): 93 | vout = data 94 | layers = [] 95 | for x, convs in enumerate(self.decode[0]): 96 | vout = tf.concat([vout, data], 3) if x else vout 97 | vout = unboxn(convs[0](vout), 2) 98 | data = unboxn(data, 2) 99 | for conv in convs[1:]: 100 | vtmp = tf.nn.elu(conv(vout)) 101 | vout = carry * vout + (1 - carry) * vtmp 102 | layers.append(vout) 103 | return self.decode[1](vout), self.variables 104 | 
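A note on the `carry` argument used by both layer classes above: every inner convolution's output is blended with its input as `vout = carry * vout + (1 - carry) * tf.nn.elu(conv(vout))`, so `carry=1` is an identity skip and `carry=0` applies the layer at full strength. Below is a self-contained NumPy sketch of this blending; the helper is illustrative and not part of the repo, and the annealing schedule is an assumption (a carry weight like this is typically decayed from 1 toward 0 early in training so new layers are eased in):

    import numpy as np

    def elu(v):
        # same nonlinearity as tf.nn.elu in layers.py
        return np.where(v > 0, v, np.exp(v) - 1)

    def carry_blend(x, layer, carry):
        # carry=1.0 passes the input through; carry=0.0 applies the layer fully
        return carry * x + (1.0 - carry) * elu(layer(x))

    w = np.random.randn(8, 8) * 0.1
    layer = lambda v: v.dot(w)            # stand-in for a convolution
    x = np.random.randn(4, 8)             # toy activations
    for carry in (1.0, 0.5, 0.0):         # e.g. annealed over training steps
        print(carry, np.abs(carry_blend(x, layer, carry) - x).mean())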
-------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import os 4 | import StringIO 5 | import scipy.misc 6 | import numpy as np 7 | from glob import glob 8 | from tqdm import trange 9 | from itertools import chain 10 | from collections import deque 11 | 12 | from models import * 13 | from utils import save_image 14 | 15 | def next(loader): 16 | return loader.next()[0].data.numpy() 17 | 18 | def to_nhwc(image, data_format): 19 | if data_format == 'NCHW': 20 | new_image = nchw_to_nhwc(image) 21 | else: 22 | new_image = image 23 | return new_image 24 | 25 | def to_nchw_numpy(image): 26 | if image.shape[3] in [1, 3]: 27 | new_image = image.transpose([0, 3, 1, 2]) 28 | else: 29 | new_image = image 30 | return new_image 31 | 32 | def norm_img(image, data_format=None): 33 | image = image/127.5 - 1. 34 | if data_format: 35 | image = to_nhwc(image, data_format) 36 | return image 37 | 38 | def denorm_img(norm, data_format): 39 | return tf.clip_by_value(to_nhwc((norm + 1)*127.5, data_format), 0, 255) 40 | 41 | def slerp(val, low, high): 42 | """Code from https://github.com/soumith/dcgan.torch/issues/14""" 43 | omega = np.arccos(np.clip(np.dot(low/np.linalg.norm(low), high/np.linalg.norm(high)), -1, 1)) 44 | so = np.sin(omega) 45 | if so == 0: 46 | return (1.0-val) * low + val * high # L'Hopital's rule/LERP 47 | return np.sin((1.0-val)*omega) / so * low + np.sin(val*omega) / so * high 48 | 49 | class Trainer(object): 50 | def __init__(self, config, data_loader): 51 | self.config = config 52 | self.data_loader = data_loader 53 | self.dataset = config.dataset 54 | 55 | self.beta1 = config.beta1 56 | self.beta2 = config.beta2 57 | self.optimizer = config.optimizer 58 | self.batch_size = config.batch_size 59 | 60 | self.step = tf.Variable(0, name='step', trainable=False) 61 | 62 | self.g_lr = tf.Variable(config.g_lr, name='g_lr') 63 | self.d_lr = tf.Variable(config.d_lr, name='d_lr') 64 | 65 | self.g_lr_update = tf.assign(self.g_lr, tf.maximum(self.g_lr * 0.5, config.lr_lower_boundary), name='g_lr_update') 66 | self.d_lr_update = tf.assign(self.d_lr, tf.maximum(self.d_lr * 0.5, config.lr_lower_boundary), name='d_lr_update') 67 | 68 | self.gamma = config.gamma 69 | self.lambda_k = config.lambda_k 70 | 71 | self.z_num = config.z_num 72 | self.conv_hidden_num = config.conv_hidden_num 73 | self.input_scale_size = config.input_scale_size 74 | 75 | self.model_dir = config.model_dir 76 | self.load_path = config.load_path 77 | 78 | self.use_gpu = config.use_gpu 79 | self.data_format = config.data_format 80 | 81 | _, height, width, self.channel = \ 82 | get_conv_shape(self.data_loader, self.data_format) 83 | self.repeat_num = int(np.log2(height)) - 2 84 | 85 | self.start_step = 0 86 | self.log_step = config.log_step 87 | self.max_step = config.max_step 88 | self.save_step = config.save_step 89 | self.lr_update_step = config.lr_update_step 90 | 91 | self.is_train = config.is_train 92 | self.build_model() 93 | 94 | self.saver = tf.train.Saver() 95 | self.summary_writer = tf.summary.FileWriter(self.model_dir) 96 | 97 | sv = tf.train.Supervisor(logdir=self.model_dir, 98 | is_chief=True, 99 | saver=self.saver, 100 | summary_op=None, 101 | summary_writer=self.summary_writer, 102 | save_model_secs=300, 103 | global_step=self.step, 104 | ready_for_local_init_op=None) 105 | 106 | gpu_options = tf.GPUOptions(allow_growth=True) 107 | sess_config = 
tf.ConfigProto(allow_soft_placement=True, 108 | gpu_options=gpu_options) 109 | 110 | self.sess = sv.prepare_or_wait_for_session(config=sess_config) 111 | 112 | if not self.is_train: 113 | # dirty way to bypass graph finalization error 114 | g = tf.get_default_graph() 115 | g._finalized = False 116 | 117 | self.build_test_model() 118 | 119 | def train(self): 120 | z_fixed = np.random.uniform(-1, 1, size=(self.batch_size, self.z_num)) 121 | 122 | x_fixed = self.get_image_from_loader() 123 | save_image(x_fixed, '{}/x_fixed.png'.format(self.model_dir)) 124 | 125 | prev_measure = 1 126 | measure_history = deque([0]*self.lr_update_step, self.lr_update_step) 127 | 128 | for step in trange(self.start_step, self.max_step): 129 | fetch_dict = { 130 | "k_update": self.k_update, 131 | "measure": self.measure, 132 | } 133 | if step % self.log_step == 0: 134 | fetch_dict.update({ 135 | "summary": self.summary_op, 136 | "g_loss": self.g_loss, 137 | "d_loss": self.d_loss, 138 | "k_t": self.k_t, 139 | }) 140 | result = self.sess.run(fetch_dict) 141 | 142 | measure = result['measure'] 143 | measure_history.append(measure) 144 | 145 | if step % self.log_step == 0: 146 | self.summary_writer.add_summary(result['summary'], step) 147 | self.summary_writer.flush() 148 | 149 | g_loss = result['g_loss'] 150 | d_loss = result['d_loss'] 151 | k_t = result['k_t'] 152 | 153 | print("[{}/{}] Loss_D: {:.6f} Loss_G: {:.6f} measure: {:.4f}, k_t: {:.4f}". \ 154 | format(step, self.max_step, d_loss, g_loss, measure, k_t)) 155 | 156 | if step % (self.log_step * 10) == 0: 157 | x_fake = self.generate(z_fixed, self.model_dir, idx=step) 158 | self.autoencode(x_fixed, self.model_dir, idx=step, x_fake=x_fake) 159 | 160 | if step % self.lr_update_step == self.lr_update_step - 1: 161 | self.sess.run([self.g_lr_update, self.d_lr_update]) 162 | #cur_measure = np.mean(measure_history) 163 | #if cur_measure > prev_measure * 0.99: 164 | #prev_measure = cur_measure 165 | 166 | def build_model(self): 167 | self.x = self.data_loader 168 | x = norm_img(self.x) 169 | 170 | self.z = tf.random_uniform( 171 | (tf.shape(x)[0], self.z_num), minval=-1.0, maxval=1.0) 172 | self.k_t = tf.Variable(0., trainable=False, name='k_t') 173 | 174 | G, self.G_var = GeneratorCNN( 175 | self.z, self.conv_hidden_num, self.channel, 176 | self.repeat_num, self.data_format, reuse=False) 177 | 178 | d_out, self.D_z, self.D_var = DiscriminatorCNN( 179 | tf.concat([G, x], 0), self.channel, self.z_num, self.repeat_num, 180 | self.conv_hidden_num, self.data_format) 181 | AE_G, AE_x = tf.split(d_out, 2) 182 | 183 | self.G = denorm_img(G, self.data_format) 184 | self.AE_G, self.AE_x = denorm_img(AE_G, self.data_format), denorm_img(AE_x, self.data_format) 185 | 186 | if self.optimizer == 'adam': 187 | optimizer = tf.train.AdamOptimizer 188 | else: 189 | raise Exception("[!] The paper only used the Adam optimizer; got {}".format(self.optimizer))
190 | 191 | g_optimizer, d_optimizer = optimizer(self.g_lr), optimizer(self.d_lr) 192 | 193 | self.d_loss_real = tf.reduce_mean(tf.abs(AE_x - x)) 194 | self.d_loss_fake = tf.reduce_mean(tf.abs(AE_G - G)) 195 | 196 | self.d_loss = self.d_loss_real - self.k_t * self.d_loss_fake 197 | self.g_loss = tf.reduce_mean(tf.abs(AE_G - G)) 198 | 199 | d_optim = d_optimizer.minimize(self.d_loss, var_list=self.D_var) 200 | g_optim = g_optimizer.minimize(self.g_loss, global_step=self.step, var_list=self.G_var) 201 | 202 | self.balance = self.gamma * self.d_loss_real - self.g_loss 203 | self.measure = self.d_loss_real + tf.abs(self.balance) 204 | 205 | with tf.control_dependencies([d_optim, g_optim]): 206 | self.k_update = tf.assign( 207 | self.k_t, tf.clip_by_value(self.k_t + self.lambda_k * self.balance, 0, 1)) 208 | 209 | self.summary_op = tf.summary.merge([ 210 | tf.summary.image("G", self.G), 211 | tf.summary.image("AE_G", self.AE_G), 212 | tf.summary.image("AE_x", self.AE_x), 213 | 214 | tf.summary.scalar("loss/d_loss", self.d_loss), 215 | tf.summary.scalar("loss/d_loss_real", self.d_loss_real), 216 | tf.summary.scalar("loss/d_loss_fake", self.d_loss_fake), 217 | tf.summary.scalar("loss/g_loss", self.g_loss), 218 | tf.summary.scalar("misc/measure", self.measure), 219 | tf.summary.scalar("misc/k_t", self.k_t), 220 | tf.summary.scalar("misc/d_lr", self.d_lr), 221 | tf.summary.scalar("misc/g_lr", self.g_lr), 222 | tf.summary.scalar("misc/balance", self.balance), 223 | ]) 224 | 225 | def build_test_model(self): 226 | with tf.variable_scope("test") as vs: 227 | # Extra ops for interpolation 228 | z_optimizer = tf.train.AdamOptimizer(0.0001) 229 | 230 | self.z_r = tf.get_variable("z_r", [self.batch_size, self.z_num], tf.float32) 231 | self.z_r_update = tf.assign(self.z_r, self.z) 232 | 233 | G_z_r, _ = GeneratorCNN( 234 | self.z_r, self.conv_hidden_num, self.channel, self.repeat_num, self.data_format, reuse=True) 235 | 236 | with tf.variable_scope("test") as vs: 237 | self.z_r_loss = tf.reduce_mean(tf.abs(self.x - G_z_r)) 238 | self.z_r_optim = z_optimizer.minimize(self.z_r_loss, var_list=[self.z_r]) 239 | 240 | test_variables = tf.contrib.framework.get_variables(vs) 241 | self.sess.run(tf.variables_initializer(test_variables)) 242 | 243 | def generate(self, inputs, root_path=None, path=None, idx=None, save=True): 244 | x = self.sess.run(self.G, {self.z: inputs}) 245 | if path is None and save: 246 | path = os.path.join(root_path, '{}_G.png'.format(idx)) 247 | save_image(x, path) 248 | print("[*] Samples saved: {}".format(path)) 249 | return x 250 | 251 | def autoencode(self, inputs, path, idx=None, x_fake=None): 252 | items = { 253 | 'real': inputs, 254 | 'fake': x_fake, 255 | } 256 | for key, img in items.items(): 257 | if img is None: 258 | continue 259 | if img.shape[3] in [1, 3]: 260 | img = img.transpose([0, 3, 1, 2]) 261 | 262 | x_path = os.path.join(path, '{}_D_{}.png'.format(idx, key)) 263 | x = self.sess.run(self.AE_x, {self.x: img}) 264 | save_image(x, x_path) 265 | print("[*] Samples saved: {}".format(x_path)) 266 | 267 | def encode(self, inputs): 268 | if inputs.shape[3] in [1, 3]: 269 | inputs = inputs.transpose([0, 3, 1, 2]) 270 | return self.sess.run(self.D_z, {self.x: inputs}) 271 | 272 | def decode(self, z): 273 | return self.sess.run(self.AE_x, {self.D_z: z}) 274 | 275 | def interpolate_G(self, real_batch, step=0, root_path='.', train_epoch=0): 276 | batch_size = len(real_batch) 277 | half_batch_size =
int(batch_size/2) 278 | 279 | self.sess.run(self.z_r_update) 280 | tf_real_batch = to_nchw_numpy(real_batch) 281 | for i in trange(train_epoch): 282 | z_r_loss, _ = self.sess.run([self.z_r_loss, self.z_r_optim], {self.x: tf_real_batch}) 283 | z = self.sess.run(self.z_r) 284 | 285 | z1, z2 = z[:half_batch_size], z[half_batch_size:] 286 | real1_batch, real2_batch = real_batch[:half_batch_size], real_batch[half_batch_size:] 287 | 288 | generated = [] 289 | for idx, ratio in enumerate(np.linspace(0, 1, 10)): 290 | z = np.stack([slerp(ratio, r1, r2) for r1, r2 in zip(z1, z2)]) 291 | z_decode = self.generate(z, save=False) 292 | generated.append(z_decode) 293 | 294 | generated = np.stack(generated).transpose([1, 0, 2, 3, 4]) 295 | for idx, img in enumerate(generated): 296 | save_image(img, os.path.join(root_path, 'test{}_interp_G_{}.png'.format(step, idx)), nrow=10) 297 | 298 | all_img_num = np.prod(generated.shape[:2]) 299 | batch_generated = np.reshape(generated, [all_img_num] + list(generated.shape[2:])) 300 | save_image(batch_generated, os.path.join(root_path, 'test{}_interp_G.png'.format(step)), nrow=10) 301 | 302 | def interpolate_D(self, real1_batch, real2_batch, step=0, root_path="."): 303 | real1_encode = self.encode(real1_batch) 304 | real2_encode = self.encode(real2_batch) 305 | 306 | decodes = [] 307 | for idx, ratio in enumerate(np.linspace(0, 1, 10)): 308 | z = np.stack([slerp(ratio, r1, r2) for r1, r2 in zip(real1_encode, real2_encode)]) 309 | z_decode = self.decode(z) 310 | decodes.append(z_decode) 311 | 312 | decodes = np.stack(decodes).transpose([1, 0, 2, 3, 4]) 313 | for idx, img in enumerate(decodes): 314 | img = np.concatenate([[real1_batch[idx]], img, [real2_batch[idx]]], 0) 315 | save_image(img, os.path.join(root_path, 'test{}_interp_D_{}.png'.format(step, idx)), nrow=10 + 2) 316 | 317 | def test(self): 318 | root_path = "./"#self.model_dir 319 | 320 | all_G_z = None 321 | for step in range(3): 322 | real1_batch = self.get_image_from_loader() 323 | real2_batch = self.get_image_from_loader() 324 | 325 | save_image(real1_batch, os.path.join(root_path, 'test{}_real1.png'.format(step))) 326 | save_image(real2_batch, os.path.join(root_path, 'test{}_real2.png'.format(step))) 327 | 328 | self.autoencode( 329 | real1_batch, self.model_dir, idx=os.path.join(root_path, "test{}_real1".format(step))) 330 | self.autoencode( 331 | real2_batch, self.model_dir, idx=os.path.join(root_path, "test{}_real2".format(step))) 332 | 333 | self.interpolate_G(real1_batch, step, root_path) 334 | #self.interpolate_D(real1_batch, real2_batch, step, root_path) 335 | 336 | z_fixed = np.random.uniform(-1, 1, size=(self.batch_size, self.z_num)) 337 | G_z = self.generate(z_fixed, path=os.path.join(root_path, "test{}_G_z.png".format(step))) 338 | 339 | if all_G_z is None: 340 | all_G_z = G_z 341 | else: 342 | all_G_z = np.concatenate([all_G_z, G_z]) 343 | save_image(all_G_z, '{}/G_z{}.png'.format(root_path, step)) 344 | 345 | save_image(all_G_z, '{}/all_G_z.png'.format(root_path), nrow=16) 346 | 347 | def get_image_from_loader(self): 348 | x = self.data_loader.eval(session=self.sess) 349 | if self.data_format == 'NCHW': 350 | x = x.transpose([0, 2, 3, 1]) 351 | return x 352 | --------------------------------------------------------------------------------
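A closing note on interpolation: the `slerp` helper above interpolates along the great circle between two latent vectors rather than the straight line between them, which keeps the interpolants at a comparable norm. A small worked example using the same function (standalone sketch, not part of the repo):

    import numpy as np

    def slerp(val, low, high):
        # identical to the slerp defined in trainer.py
        omega = np.arccos(np.clip(np.dot(low/np.linalg.norm(low), high/np.linalg.norm(high)), -1, 1))
        so = np.sin(omega)
        if so == 0:
            return (1.0-val) * low + val * high  # degenerate case: plain lerp
        return np.sin((1.0-val)*omega) / so * low + np.sin(val*omega) / so * high

    z1 = np.array([1.0, 0.0])
    z2 = np.array([0.0, 1.0])
    for t in np.linspace(0, 1, 5):
        z = slerp(t, z1, z2)
        print(t, z, np.linalg.norm(z))  # the norm stays 1.0 along the whole path

`interpolate_G` and `interpolate_D` evaluate this at ten evenly spaced ratios per image pair (`np.linspace(0, 1, 10)`).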