├── .gitignore ├── README.md ├── requirements.d ├── base.txt └── venv.txt ├── src ├── cache.py ├── cifar10.py ├── dataset.py ├── dataset_.py ├── download.py ├── download2.py ├── inception_score.py ├── jupyter │ ├── .idea │ │ └── vcs.xml │ ├── gan_toy.ipynb │ ├── record_video.py │ ├── tensorflow_utils.py │ └── utils.py ├── main.py ├── plot.py ├── solver.py ├── tensorflow_utils.py ├── utils.py └── wgan_gp.py └── tox.ini /.gitignore: -------------------------------------------------------------------------------- 1 | src/jupyter/.idea/ 2 | src/jupyter/.ipynb_checkpoints/ 3 | src/jupyter/__pycache__/ 4 | src/jupyter/*.jpg 5 | src/jupyter/img/ 6 | src/jupyter/*.gif 7 | src/jupyter/*.mp4 8 | src/.idea/ 9 | src/mnist/ 10 | src/imagenet/ 11 | src/imagenet64/ 12 | src/cifar10/ 13 | src/__pycache__/ 14 | github_imgs/ 15 | 16 | # Byte-compiled / optimized / DLL files 17 | __pycache__/ 18 | *.py[cod] 19 | *$py.class 20 | 21 | # C extensions 22 | *.so 23 | 24 | # Distribution / packaging 25 | .Python 26 | build/ 27 | develop-eggs/ 28 | dist/ 29 | downloads/ 30 | eggs/ 31 | .eggs/ 32 | lib/ 33 | lib64/ 34 | parts/ 35 | sdist/ 36 | var/ 37 | wheels/ 38 | pip-wheel-metadata/ 39 | share/python-wheels/ 40 | *.egg-info/ 41 | .installed.cfg 42 | *.egg 43 | MANIFEST 44 | 45 | # PyInstaller 46 | # Usually these files are written by a python script from a template 47 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
48 | *.manifest 49 | *.spec 50 | 51 | # Installer logs 52 | pip-log.txt 53 | pip-delete-this-directory.txt 54 | 55 | # Unit test / coverage reports 56 | htmlcov/ 57 | .tox/ 58 | .nox/ 59 | .coverage 60 | .coverage.* 61 | .cache 62 | nosetests.xml 63 | coverage.xml 64 | *.cover 65 | .hypothesis/ 66 | .pytest_cache/ 67 | 68 | # Translations 69 | *.mo 70 | *.pot 71 | 72 | # Django stuff: 73 | *.log 74 | local_settings.py 75 | db.sqlite3 76 | 77 | # Flask stuff: 78 | instance/ 79 | .webassets-cache 80 | 81 | # Scrapy stuff: 82 | .scrapy 83 | 84 | # Sphinx documentation 85 | docs/_build/ 86 | 87 | # PyBuilder 88 | target/ 89 | 90 | # Jupyter Notebook 91 | .ipynb_checkpoints 92 | 93 | # IPython 94 | profile_default/ 95 | ipython_config.py 96 | 97 | # pyenv 98 | .python-version 99 | 100 | # pipenv 101 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 102 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 103 | # having no cross-platform support, pipenv may install dependencies that don’t work, or not 104 | # install all needed dependencies. 
105 | #Pipfile.lock 106 | 107 | # celery beat schedule file 108 | celerybeat-schedule 109 | 110 | # SageMath parsed files 111 | *.sage.py 112 | 113 | # Environments 114 | .env 115 | .venv 116 | env/ 117 | venv/ 118 | ENV/ 119 | env.bak/ 120 | venv.bak/ 121 | 122 | # Spyder project settings 123 | .spyderproject 124 | .spyproject 125 | 126 | # Rope project settings 127 | .ropeproject 128 | 129 | # mkdocs documentation 130 | /site 131 | 132 | # mypy 133 | .mypy_cache/ 134 | .dmypy.json 135 | dmypy.json 136 | 137 | # Pyre type checker 138 | .pyre/ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # WGAN-GP-tensorflow 2 | This repository is a Tensorflow implementation of the [WGAN-GP](https://arxiv.org/abs/1704.00028) for MNIST, CIFAR-10, and ImageNet64. 3 | 4 |

5 | 7 | 8 | * *All samples in README.md are generated by the neural network except the first image of each row.* 9 | 10 | ## Install Prerequisites 11 | 12 | * python 3.5, 3.6 or 3.7 13 | * python3-tk 14 | 15 | Ubuntu/Debian/etc.: 16 | 17 | sudo apt install python3.5 python3.5-tk 18 | 19 | ## Create Virtual Environment 20 | 21 | python -m venv venv 22 | 23 | ## Activate Virtual Environment 24 | 25 | Windows: 26 | 27 | venv/Scripts/activate 28 | 29 | Bash: 30 | 31 | source venv/bin/activate 32 | 33 | ## Install Virtual Environment Requirements 34 | 35 | pip install -r requirements.d/venv.txt 36 | 37 | ## Create Execution Environments 38 | 39 | tox --notest 40 | 41 | That will install tensorflow which uses only the CPU. 42 | 43 | To use an Nvidia GPU: 44 | 45 | .tox/py35/bin/python -m pip uninstall tensorflow 46 | .tox/py35/bin/python -m pip install tensorflow-gpu==1.13.1 47 | .tox/py36/bin/python -m pip uninstall tensorflow 48 | .tox/py36/bin/python -m pip install tensorflow-gpu==1.13.1 49 | .tox/py37/bin/python -m pip uninstall tensorflow 50 | .tox/py37/bin/python -m pip install tensorflow-gpu==1.13.1 51 | 52 | To use an AMD GPU: 53 | 54 | .tox/py35/bin/python -m pip uninstall tensorflow 55 | .tox/py35/bin/python -m pip install tensorflow-rocm==1.13.1 56 | .tox/py36/bin/python -m pip uninstall tensorflow 57 | .tox/py36/bin/python -m pip install tensorflow-rocm==1.13.1 58 | .tox/py37/bin/python -m pip uninstall tensorflow 59 | .tox/py37/bin/python -m pip install tensorflow-rocm==1.13.1 60 | 61 | ## Generated Images 62 | ### 1. Toy Dataset 63 | Results on 2-dimensional toy data: the 8 Gaussian Mixture Models, 25 Gaussian Mixture Models, and Swiss Roll data. [Ipython Notebook](https://github.com/ChengBinJin/WGAN-GP-tensorflow/tree/master/src/jupyter). 64 | 65 | **Note:** To demonstrate the following experiment, we held the generator distribution Pg fixed at the real distribution plus unit-variance Gaussian noise. 
66 | - **Top:** GAN discriminator 67 | - **Middle:** WGAN critic with weight clipping 68 | - **Bottom:** WGAN critic with weight penalty 69 |

70 | 71 | 72 | 73 |

74 | 75 | **Note:** For the next experiment, we did not fix generator and showed generated points by the generator. 76 | - **Top:** GAN discriminator 77 | - **Middle:** WGAN critic with weight clipping 78 | - **Bottom:** WGAN critic with weight penalty 79 |

80 | 81 | 82 | 83 |

84 | 85 | ### 2. MNIST Dataset 86 |

87 | 88 |

89 | 90 |

91 | 92 |

93 | 94 | ### 3. CIFAR-10 95 |

96 | 97 |

98 | 99 |

100 | 101 |

102 | 103 | ### 4. IMAGENET64 104 |

105 | 106 |

107 | 108 |

109 | 110 |

111 | 112 | ## Documentation 113 | ### Download Dataset 114 | The 'MNIST' and 'CIFAR10' datasets are downloaded automatically by the code if they are not found in the expected folder. The 'ImageNet64' dataset can be downloaded from the [Downsampled ImageNet](http://image-net.org/small/download.php). 115 | 116 | ### Directory Hierarchy 117 | ``` 118 | . 119 | │ WGAN-GP 120 | │ ├── src 121 | │ │ ├── imagenet (folder storing the inception network weights downloaded by inception_score.py) 122 | │ │ ├── cache.py 123 | │ │ ├── cifar10.py 124 | │ │ ├── dataset.py 125 | │ │ ├── dataset_.py 126 | │ │ ├── download.py 127 | │ │ ├── inception_score.py 128 | │ │ ├── main.py 129 | │ │ ├── plot.py 130 | │ │ ├── solver.py 131 | │ │ ├── tensorflow_utils.py 132 | │ │ ├── utils.py 133 | │ │ └── wgan_gp.py 134 | │ Data 135 | │ ├── mnist 136 | │ ├── cifar10 137 | │ └── imagenet64 138 | ``` 139 | **src**: source code of the WGAN-GP 140 | 141 | ### Training WGAN-GP 142 | Use `main.py` to train a WGAN-GP network. 
Example usage: 143 | 144 | ``` 145 | python main.py 146 | ``` 147 | - `gpu_index`: gpu index, default: `0` 148 | - `batch_size`: batch size for one feed forward, default: `64` 149 | - `dataset`: dataset name from [mnist, cifar10, imagenet64], default: `mnist` 150 | 151 | - `is_train`: training or inference mode, default: `True` 152 | - `learning_rate`: initial learning rate for Adam, default: `0.001` 153 | - `num_critic`: the number of iterations of the critic per generator iteration, default: `5` 154 | - `z_dim`: dimension of z vector, default: `128` 155 | - `lambda_`: gradient penalty lambda hyperparameter, default: `10.` 156 | - `beta1`: beta1 momentum term of Adam, default: `0.5` 157 | - `beta2`: beta2 momentum term of Adam, default: `0.9` 158 | 159 | - `iters`: number of iterations, default: `200000` 160 | - `print_freq`: print frequency for loss, default: `100` 161 | - `save_freq`: save frequency for model, default: `10000` 162 | - `sample_freq`: sample frequency for saving image, default: `500` 163 | - `inception_freq`: calculation frequency of the inception score, default: `1000` 164 | - `sample_batch`: number of sampled images for checking generator quality, default: `64` 165 | - `load_model`: folder of the saved model that you wish to test (e.g. 20181120-1558), default: `None` 166 | 167 | ### WGAN-GP During Training 168 | **Note:** In the following figures, the Y axes show the negative critic loss for the WGAN-GP. 169 | 1. **MNIST** 170 |

171 | 172 |

173 | 174 | 2. **CIFAR10** 175 |

176 | 177 |

178 | 179 | 3. **IMAGENET64** 180 |

181 | 182 |

183 | 184 | ### Inception Score on CIFAR10 During Training 185 | **Note:** Inception score was calculated every 1000 iterations. 186 |

187 | 188 |

189 | 190 | ### Test WGAN-GP 191 | Use `main.py` to test a WGAN-GP network. Example usage: 192 | 193 | ``` 194 | python main.py --is_train=false --load_model=folder/you/wish/to/test/e.g./20181120-1558 195 | ``` 196 | Please refer to the above arguments. 197 | 198 | ### Citation 199 | ``` 200 | @misc{chengbinjin2018wgan-gp, 201 | author = {Cheng-Bin Jin}, 202 | title = {WGAN-GP-tensorflow}, 203 | year = {2018}, 204 | howpublished = {\url{https://github.com/ChengBinJin/WGAN-GP-tensorflow}}, 205 | note = {commit xxxxxxx} 206 | } 207 | ``` 208 | 209 | ### Attributions/Thanks 210 | - This project borrowed some code from [igul222](https://github.com/igul222/improved_wgan_training). 211 | - Some readme formatting was borrowed from [Logan Engstrom](https://github.com/lengstrom/fast-style-transfer). 212 | 213 | ## License 214 | Copyright (c) 2018 Cheng-Bin Jin. Contact me for commercial use (or rather any use that is not academic research) (email: sbkim0407@gmail.com). Free for research use, as long as proper attribution is given and this copyright notice is retained. 
215 | 216 | ## Related Projects 217 | - [Vanilla GAN](https://github.com/ChengBinJin/VanillaGAN-TensorFlow) 218 | - [DCGAN](https://github.com/ChengBinJin/DCGAN-TensorFlow) 219 | - [WGAN](https://github.com/ChengBinJin/WGAN-TensorFlow) 220 | -------------------------------------------------------------------------------- /requirements.d/base.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.7.1 2 | astor==0.7.1 3 | attrs==19.1.0 4 | certifi==2019.3.9 5 | chardet==3.0.4 6 | cycler==0.10.0 7 | gast==0.2.2 8 | grpcio==1.20.1 9 | h5py==2.9.0 10 | idna==2.8 11 | jsonschema==3.0.1 12 | Keras==2.2.4 13 | Keras-Applications==1.0.7 14 | Keras-Preprocessing==1.0.9 15 | kiwisolver==1.1.0 16 | Markdown==3.1 17 | matplotlib==3.0.3 18 | mock==3.0.4 19 | numpy==1.16.3 20 | Pillow==6.2.0 21 | protobuf==3.7.1 22 | pyparsing==2.4.0 23 | pyrsistent==0.15.1 24 | python-dateutil==2.8.0 25 | PyYAML==5.1 26 | requests==2.21.0 27 | scipy==1.2.1 28 | six==1.12.0 29 | tensorboard==1.13.1 30 | tensorflow==1.15.0 31 | tensorflow-estimator==1.13.0 32 | termcolor==1.1.0 33 | tqdm==4.31.1 34 | urllib3==1.25.2 35 | Werkzeug==0.15.3 36 | -------------------------------------------------------------------------------- /requirements.d/venv.txt: -------------------------------------------------------------------------------- 1 | filelock==3.0.10 2 | pluggy==0.9.0 3 | py==1.8.0 4 | six==1.12.0 5 | toml==0.10.0 6 | tox==3.9.0 7 | virtualenv==16.5.0 8 | -------------------------------------------------------------------------------- /src/cache.py: -------------------------------------------------------------------------------- 1 | ######################################################################## 2 | # 3 | # Cache-wrapper for a function or class. 4 | # 5 | # Save the result of calling a function or creating an object-instance 6 | # to harddisk. This is used to persist the data so it can be reloaded 7 | # very quickly and easily. 
def _dump_atomically(obj_, cache_path):
    """
    Pickle obj_ to cache_path without ever leaving a partial cache-file.
    The data is first written to a sibling temporary file which is then
    renamed over cache_path, so cache_path either holds a complete pickle
    or does not exist at all.
    :param obj_:
        Object to be pickled.
    :param cache_path:
        Final file-path for the cache-file (a string).
    :return:
        Nothing.
    """

    tmp_path = cache_path + '.tmp'

    try:
        with open(tmp_path, mode='wb') as file:
            pickle.dump(obj_, file)

        # Rename only after the pickle has been written completely.
        os.replace(tmp_path, cache_path)
    except BaseException:
        # Pickling (or the rename) failed: remove the partial temp-file so
        # that nothing half-written is left behind, then re-raise.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise


def cache(cache_path, fn, *args, **kwargs):
    """
    Cache-wrapper for a function or class. If the cache-file exists
    then the data is reloaded and returned, otherwise the function
    is called and the result is saved to cache. The fn-argument can
    also be a class instead, in which case an object-instance is
    created and saved to the cache-file.
    If saving fails (e.g. the result cannot be pickled) the exception is
    propagated and no cache-file is left behind, so a later call does not
    try to load a truncated file.
    :param cache_path:
        File-path for the cache-file.
    :param fn:
        Function or class to be called.
    :param args:
        Arguments to the function or class-init.
    :param kwargs:
        Keyword arguments to the function or class-init.
    :return:
        The result of calling the function or creating the object-instance.
    """

    # If the cache-file exists.
    if os.path.exists(cache_path):
        # Load the cached data from the file.
        # NOTE(review): pickle.load executes arbitrary code from the file,
        # so cache-files must only come from a trusted location.
        with open(cache_path, mode='rb') as file:
            obj_ = pickle.load(file)

        print("- Data loaded from cache-file: " + cache_path)
    else:
        # The cache-file does not exist.

        # Call the function / class-init with the supplied arguments.
        obj_ = fn(*args, **kwargs)

        # Save the data to a cache-file (all-or-nothing).
        _dump_atomically(obj_, cache_path)

        print("- Data saved to cache-file: " + cache_path)

    return obj_


########################################################################


def convert_numpy2pickle(in_path, out_path):
    """
    Convert a numpy-file to pickle-file.
    The first version of the cache-function used numpy for saving the data.
    Instead of re-calculating all the data, you can just convert the
    cache-file using this function.
    :param in_path:
        Input file in numpy-format written using numpy.save().
    :param out_path:
        Output file written as a pickle-file.
    :return:
        Nothing.
    """

    # Load the data using numpy.
    # allow_pickle=True is required because arbitrary Python objects saved
    # with numpy.save() are stored as pickled object-arrays, and recent
    # numpy versions refuse to load those by default.
    # NOTE(review): as with any pickle, only convert files you trust.
    data = np.load(in_path, allow_pickle=True)

    # Save the data using pickle.
    with open(out_path, mode='wb') as file:
        pickle.dump(data, file)


########################################################################

if __name__ == '__main__':
    # This is a short example of using a cache-file.

    # This is the function that will only get called if the result
    # is not already saved in the cache-file. This would normally
    # be a function that takes a long time to compute, or if you
    # need persistent data for some other reason.
    def expensive_function(a, b):
        return a * b

    print('Computing expensive_function() ...')

    # Either load the result from a cache-file if it already exists,
    # otherwise calculate expensive_function(a=123, b=456) and
    # save the result to the cache-file for next time.
    result = cache(cache_path='cache_expensive_function.pkl',
                   fn=expensive_function, a=123, b=456)

    print('result =', result)

    # Newline.
    print()

    # This is another example which saves an object to a cache-file.

    # We want to cache an object-instance of this class.
    # The motivation is to do an expensive computation only once,
    # or if we need to persist the data for some other reason.
    class ExpensiveClass:
        def __init__(self, c, d):
            self.c = c
            self.d = d
            self.result = c * d

        def print_result(self):
            print('c =', self.c)
            print('d =', self.d)
            print('result = c * d =', self.result)

    print('Creating object from ExpensiveClass() ...')

    # Either load the object from a cache-file if it already exists,
    # otherwise make an object-instance ExpensiveClass(c=123, d=456)
    # and save the object to the cache-file for the next time.
    obj = cache(cache_path='cache_ExpensiveClass.pkl',
                fn=ExpensiveClass, c=123, d=456)

    obj.print_result()

########################################################################
17 | # 18 | # Format: 19 | # The images for the training- and test-sets are returned as 4-dim numpy 20 | # arrays each with the shape: [image_number, height, width, channel] 21 | # where the individual pixels are floats between 0.0 and 1.0. 22 | # 23 | ######################################################################## 24 | # 25 | # This file is part of the TensorFlow Tutorials available at: 26 | # 27 | # https://github.com/Hvass-Labs/TensorFlow-Tutorials 28 | # 29 | # Published under the MIT License. See the file LICENSE for details. 30 | # 31 | # Copyright 2016 by Magnus Erik Hvass Pedersen 32 | # 33 | ######################################################################## 34 | 35 | import numpy as np 36 | import pickle 37 | import os 38 | import download 39 | from dataset import one_hot_encoded 40 | 41 | ######################################################################## 42 | 43 | # Directory where you want to download and save the data-set. 44 | # Set this before you start calling any of the functions below. 45 | data_path = "../../Data/cifar10/" 46 | 47 | # URL for the data-set on the internet. 48 | data_url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz" 49 | 50 | ######################################################################## 51 | # Various constants for the size of the images. 52 | # Use these constants in your own program. 53 | 54 | # Width and height of each image. 55 | img_size = 32 56 | 57 | # Number of channels in each image, 3 channels: Red, Green, Blue. 58 | num_channels = 3 59 | 60 | # Length of an image when flattened to a 1-dim array. 61 | img_size_flat = img_size * img_size * num_channels 62 | 63 | # Number of classes. 64 | num_classes = 10 65 | 66 | ######################################################################## 67 | # Various constants used to allocate arrays of the correct size. 68 | 69 | # Number of files for the training-set. 
# The training-set is distributed as this many batch-files.
_num_files_train = 5

# Each training batch-file holds this many images.
_images_per_file = 10000

# Size of the whole training-set; used to pre-allocate the arrays
# in load_training_data().
_num_images_train = _num_files_train * _images_per_file

########################################################################
# Private functions for downloading, unpacking and loading data-files.


def _get_file_path(filename=""):
    """
    Build the path of a data-file inside the extracted CIFAR-10 directory.
    With the default filename="" the directory itself is returned.
    """

    batches_dir = "cifar-10-batches-py/"
    return os.path.join(data_path, batches_dir, filename)


def _unpickle(filename):
    """
    Load one pickled CIFAR-10 file and return its content.
    The filename is resolved relative to the data-set directory.
    """

    full_path = _get_file_path(filename)
    print("Loading data: " + full_path)

    # encoding='bytes' is needed under Python 3 to read these files,
    # otherwise an exception is raised.
    with open(full_path, mode='rb') as handle:
        return pickle.load(handle, encoding='bytes')


def _convert_images(raw):
    """
    Turn raw CIFAR-10 image data into a 4-dim float array of shape
    [image_number, height, width, channel] with pixel values scaled
    by 1/255.
    """

    # Scale to floating-point values.
    scaled = np.array(raw, dtype=float) / 255.0

    # The raw layout is channel-first; reshape, then move the channel
    # axis to the end.
    channel_first = scaled.reshape([-1, num_channels, img_size, img_size])
    return channel_first.transpose([0, 2, 3, 1])


def _load_data(filename):
    """
    Read one pickled CIFAR-10 batch-file and return a tuple of the
    converted images and the class-number of each image (numpy-array).
    """

    batch = _unpickle(filename)
    return _convert_images(batch[b'data']), np.array(batch[b'labels'])


########################################################################
# Public functions that you may call to download the data-set from
# the internet and load the data into memory.


def maybe_download_and_extract():
    """
    Fetch and unpack the CIFAR-10 data-set into data_path unless it is
    already there (set the data_path variable before calling this).
    """

    download.maybe_download_and_extract(url=data_url, download_dir=data_path)


def load_class_names():
    """
    Return the list of CIFAR-10 class-names, indexed by class-number,
    e.g. names[3] is the name of class-number 3.
    """

    names = []

    # The meta-file stores the names as byte-strings.
    for raw_name in _unpickle(filename="batches.meta")[b'label_names']:
        names.append(raw_name.decode('utf-8'))

    return names


def load_training_data():
    """
    Load the complete CIFAR-10 training-set by merging its 5 batch-files.
    Returns the images, class-numbers and one-hot encoded class-labels.
    """

    # Allocate the full-size output arrays up front.
    all_images = np.zeros(shape=[_num_images_train, img_size, img_size, num_channels], dtype=float)
    all_cls = np.zeros(shape=[_num_images_train], dtype=int)

    # Position in the output arrays where the next batch is written.
    start = 0

    # The batch-files are named data_batch_1 ... data_batch_5.
    for file_number in range(1, _num_files_train + 1):
        batch_images, batch_cls = _load_data(filename="data_batch_" + str(file_number))

        stop = start + len(batch_images)

        all_images[start:stop, :] = batch_images
        all_cls[start:stop] = batch_cls

        # The next batch continues where this one ended.
        start = stop

    labels = one_hot_encoded(class_numbers=all_cls, num_classes=num_classes)
    return all_images, all_cls, labels


def load_test_data():
    """
    Load the CIFAR-10 test-set from its single batch-file.
    Returns the images, class-numbers and one-hot encoded class-labels.
    """

    test_images, test_cls = _load_data(filename="test_batch")
    labels = one_hot_encoded(class_numbers=test_cls, num_classes=num_classes)

    return test_images, test_cls, labels

########################################################################
6 | # 7 | # Implemented in Python 3.5 8 | # 9 | ######################################################################## 10 | # 11 | # This file is part of the TensorFlow Tutorials available at: 12 | # 13 | # https://github.com/Hvass-Labs/TensorFlow-Tutorials 14 | # 15 | # Published under the MIT License. See the file LICENSE for details. 16 | # 17 | # Copyright 2016 by Magnus Erik Hvass Pedersen 18 | # 19 | ######################################################################## 20 | 21 | import numpy as np 22 | import os 23 | import shutil 24 | from cache import cache 25 | 26 | 27 | ######################################################################## 28 | 29 | 30 | def one_hot_encoded(class_numbers, num_classes=None): 31 | """ 32 | Generate the One-Hot encoded class-labels from an array of integers. 33 | For example, if class_number=2 and num_classes=4 then 34 | the one-hot encoded label is the float array: [0. 0. 1. 0.] 35 | :param class_numbers: 36 | Array of integers with class-numbers. 37 | Assume the integers are from zero to num_classes-1 inclusive. 38 | :param num_classes: 39 | Number of classes. If None then use max(class_numbers)+1. 40 | :return: 41 | 2-dim array of shape: [len(class_numbers), num_classes] 42 | """ 43 | 44 | # Find the number of classes if None is provided. 45 | # Assumes the lowest class-number is zero. 46 | if num_classes is None: 47 | num_classes = np.max(class_numbers) + 1 48 | 49 | return np.eye(num_classes, dtype=float)[class_numbers] 50 | 51 | 52 | ######################################################################## 53 | 54 | 55 | class DataSet: 56 | def __init__(self, in_dir, exts='.jpg'): 57 | """ 58 | Create a data-set consisting of the filenames in the given directory 59 | and sub-dirs that match the given filename-extensions. 
60 | For example, the knifey-spoony data-set (see knifey.py) has the 61 | following dir-structure: 62 | knifey-spoony/forky/ 63 | knifey-spoony/knifey/ 64 | knifey-spoony/spoony/ 65 | knifey-spoony/forky/test/ 66 | knifey-spoony/knifey/test/ 67 | knifey-spoony/spoony/test/ 68 | This means there are 3 classes called: forky, knifey, and spoony. 69 | If we set in_dir = "knifey-spoony/" and create a new DataSet-object 70 | then it will scan through these directories and create a training-set 71 | and test-set for each of these classes. 72 | The training-set will contain a list of all the *.jpg filenames 73 | in the following directories: 74 | knifey-spoony/forky/ 75 | knifey-spoony/knifey/ 76 | knifey-spoony/spoony/ 77 | The test-set will contain a list of all the *.jpg filenames 78 | in the following directories: 79 | knifey-spoony/forky/test/ 80 | knifey-spoony/knifey/test/ 81 | knifey-spoony/spoony/test/ 82 | See the TensorFlow Tutorial #09 for a usage example. 83 | :param in_dir: 84 | Root-dir for the files in the data-set. 85 | This would be 'knifey-spoony/' in the example above. 86 | :param exts: 87 | String or tuple of strings with valid filename-extensions. 88 | Not case-sensitive. 89 | :return: 90 | Object instance. 91 | """ 92 | 93 | # Extend the input directory to the full path. 94 | in_dir = os.path.abspath(in_dir) 95 | 96 | # Input directory. 97 | self.in_dir = in_dir 98 | 99 | # Convert all file-extensions to lower-case. 100 | self.exts = tuple(ext.lower() for ext in exts) 101 | 102 | # Names for the classes. 103 | self.class_names = [] 104 | 105 | # Filenames for all the files in the training-set. 106 | self.filenames = [] 107 | 108 | # Filenames for all the files in the test-set. 109 | self.filenames_test = [] 110 | 111 | # Class-number for each file in the training-set. 112 | self.class_numbers = [] 113 | 114 | # Class-number for each file in the test-set. 115 | self.class_numbers_test = [] 116 | 117 | # Total number of classes in the data-set. 
118 | self.num_classes = 0 119 | 120 | # For all files/dirs in the input directory. 121 | for name in os.listdir(in_dir): 122 | # Full path for the file / dir. 123 | current_dir = os.path.join(in_dir, name) 124 | 125 | # If it is a directory. 126 | if os.path.isdir(current_dir): 127 | # Add the dir-name to the list of class-names. 128 | self.class_names.append(name) 129 | 130 | # Training-set. 131 | 132 | # Get all the valid filenames in the dir (not sub-dirs). 133 | filenames = self._get_filenames(current_dir) 134 | 135 | # Append them to the list of all filenames for the training-set. 136 | self.filenames.extend(filenames) 137 | 138 | # The class-number for this class. 139 | class_number = self.num_classes 140 | 141 | # Create an array of class-numbers. 142 | class_numbers = [class_number] * len(filenames) 143 | 144 | # Append them to the list of all class-numbers for the training-set. 145 | self.class_numbers.extend(class_numbers) 146 | 147 | # Test-set. 148 | 149 | # Get all the valid filenames in the sub-dir named 'test'. 150 | filenames_test = self._get_filenames(os.path.join(current_dir, 'test')) 151 | 152 | # Append them to the list of all filenames for the test-set. 153 | self.filenames_test.extend(filenames_test) 154 | 155 | # Create an array of class-numbers. 156 | class_numbers = [class_number] * len(filenames_test) 157 | 158 | # Append them to the list of all class-numbers for the test-set. 159 | self.class_numbers_test.extend(class_numbers) 160 | 161 | # Increase the total number of classes in the data-set. 162 | self.num_classes += 1 163 | 164 | def _get_filenames(self, dir_): 165 | """ 166 | Create and return a list of filenames with matching extensions in the given directory. 167 | :param dir_: 168 | Directory to scan for files. Sub-dirs are not scanned. 169 | :return: 170 | List of filenames. Only filenames. Does not include the directory. 171 | """ 172 | 173 | # Initialize empty list. 174 | filenames = [] 175 | 176 | # If the directory exists. 
177 | if os.path.exists(dir_): 178 | # Get all the filenames with matching extensions. 179 | for filename in os.listdir(dir_): 180 | if filename.lower().endswith(self.exts): 181 | filenames.append(filename) 182 | 183 | return filenames 184 | 185 | def get_paths(self, test=False): 186 | """ 187 | Get the full paths for the files in the data-set. 188 | :param test: 189 | Boolean. Return the paths for the test-set (True) or training-set (False). 190 | :return: 191 | Iterator with strings for the path-names. 192 | """ 193 | 194 | if test: 195 | # Use the filenames and class-numbers for the test-set. 196 | filenames = self.filenames_test 197 | class_numbers = self.class_numbers_test 198 | 199 | # Sub-dir for test-set. 200 | test_dir = "test/" 201 | else: 202 | # Use the filenames and class-numbers for the training-set. 203 | filenames = self.filenames 204 | class_numbers = self.class_numbers 205 | 206 | # Don't use a sub-dir for test-set. 207 | test_dir = "" 208 | 209 | for filename, cls in zip(filenames, class_numbers): 210 | # Full path-name for the file. 211 | path = os.path.join(self.in_dir, self.class_names[cls], test_dir, filename) 212 | 213 | yield path 214 | 215 | def get_training_set(self): 216 | """ 217 | Return the list of paths for the files in the training-set, 218 | and the list of class-numbers as integers, 219 | and the class-numbers as one-hot encoded arrays. 220 | """ 221 | 222 | return list(self.get_paths()), np.asarray(self.class_numbers), one_hot_encoded( 223 | class_numbers=self.class_numbers, num_classes=self.num_classes) 224 | 225 | def get_test_set(self): 226 | """ 227 | Return the list of paths for the files in the test-set, 228 | and the list of class-numbers as integers, 229 | and the class-numbers as one-hot encoded arrays. 
def copy_files(self, train_dir, test_dir):
    """
    Copy all the files in the training-set to train_dir
    and copy all the files in the test-set to test_dir.

    For example, the normal directory structure for the
    different classes in the training-set is:

    knifey-spoony/forky/
    knifey-spoony/knifey/
    knifey-spoony/spoony/

    Normally the test-set is a sub-dir of the training-set:

    knifey-spoony/forky/test/
    knifey-spoony/knifey/test/
    knifey-spoony/spoony/test/

    But some APIs use another dir-structure for the training-set:

    knifey-spoony/train/forky/
    knifey-spoony/train/knifey/
    knifey-spoony/train/spoony/

    and for the test-set:

    knifey-spoony/test/forky/
    knifey-spoony/test/knifey/
    knifey-spoony/test/spoony/

    :param train_dir: Directory for the training-set e.g. 'knifey-spoony/train/'
    :param test_dir: Directory for the test-set e.g. 'knifey-spoony/test/'
    :return: Nothing.
    """

    # Helper-function for actually copying the files.
    def _copy_files(src_paths, dst_dir, class_numbers):
        # One destination dir per class, indexed by class-number, e.g.:
        # ['knifey-spoony/test/forky/',
        #  'knifey-spoony/test/knifey/',
        #  'knifey-spoony/test/spoony/']
        class_dirs = [os.path.join(dst_dir, class_name + "/")
                      for class_name in self.class_names]

        # Create the class-directories. exist_ok avoids the window between
        # an existence check and the creation, and makes re-runs harmless.
        for dir_ in class_dirs:
            os.makedirs(dir_, exist_ok=True)

        # Copy every file into the directory of its class.
        for src, cls in zip(src_paths, class_numbers):
            shutil.copy(src=src, dst=class_dirs[cls])

    # Copy the files for the training-set.
    _copy_files(src_paths=self.get_paths(test=False),
                dst_dir=train_dir,
                class_numbers=self.class_numbers)

    print("- Copied training-set to:", train_dir)

    # Copy the files for the test-set.
    _copy_files(src_paths=self.get_paths(test=True),
                dst_dir=test_dir,
                class_numbers=self.class_numbers_test)

    print("- Copied test-set to:", test_dir)
class MnistDataset(object):
    """
    MNIST training-batch provider backed by a TensorFlow one-shot iterator.

    Batches are fetched through the given session, resized to the spatial
    size in ``image_size`` and rescaled before being returned.
    """

    def __init__(self, sess, flags, dataset_name):
        """
        :param sess: Session used to evaluate the batch tensor.
        :param flags: Settings object; ``flags.batch_size`` is read here.
        :param dataset_name: Name used in log messages and in the data path.
        """
        self.sess = sess
        self.flags = flags
        self.dataset_name = dataset_name
        self.image_size = (32, 32, 1)
        self.img_buffle = 100000  # buffer size handed to Dataset.shuffle
        self.num_trains = 0
        self.num_tests = 0

        self.mnist_path = os.path.join('../../Data', self.dataset_name)
        self._load_mnist()

    def _load_mnist(self):
        """Load MNIST through Keras and build the shuffled, repeating batch pipeline."""
        logger.info('Load {} dataset...'.format(self.dataset_name))

        # Each split is a tuple: (images, labels).
        self.train_data, self.test_data = tf.keras.datasets.mnist.load_data()
        self.num_trains = self.train_data[0].shape[0]
        self.num_tests = self.test_data[0].shape[0]

        # TensorFlow Dataset API: shuffle -> repeat -> batch.
        images, labels = self.train_data
        pipeline = tf.data.Dataset.from_tensor_slices(({'image': images}, labels))
        pipeline = pipeline.shuffle(self.img_buffle)
        pipeline = pipeline.repeat()
        pipeline = pipeline.batch(self.flags.batch_size)

        self.next_batch = pipeline.make_one_shot_iterator().get_next()

        for message in ('Load {} dataset SUCCESS!'.format(self.dataset_name),
                        'Img size: {}'.format(self.image_size),
                        'Num. of training data: {}'.format(self.num_trains)):
            logger.info(message)

    def train_next_batch(self, batch_size):
        """
        Fetch the next batch of training images.

        :param batch_size:
            Number of images wanted. At most ``flags.batch_size`` images
            are returned, because that is the size of a pipeline batch.
        :return:
            float32 array with a trailing channel axis, computed as
            ``resized / 127.5 - 1``.
        """
        fetched = self.sess.run(self.next_batch)
        raw = fetched[0]["image"]  # fetched[1] holds the labels (unused)

        # Trim the pipeline batch when fewer images were requested.
        wanted = self.flags.batch_size
        if batch_size < wanted:
            raw = raw[:batch_size]
            wanted = batch_size
        digits = np.reshape(raw, [wanted, 28, 28])

        # Resize every digit to the target height/width
        # (original note: scipy.misc.imresize converts to uint8).
        target_hw = self.image_size[0:2]
        resized = [scipy.misc.imresize(digit, target_hw) for digit in digits]
        stacked = np.asarray(resized).astype(np.float32)[:, :, :, np.newaxis]

        # from [0., 255.] to [-1., 1.]
        return stacked / 127.5 - 1.
class Cifar10(object):
    """
    CIFAR-10 training-image provider.

    The whole training-set is loaded into memory once; batches are drawn
    from it at random, without replacement within a batch.
    """

    def __init__(self, flags, dataset_name):
        """
        :param flags: Settings object, stored as-is.
        :param dataset_name: Name used in log messages and in the data path.
        """
        self.flags = flags
        self.dataset_name = dataset_name
        self.image_size = (32, 32, 3)
        self.num_trains = 0

        self.cifar10_path = os.path.join('../../Data', self.dataset_name)
        self._load_cifar10()

    def _load_cifar10(self):
        """Point the cifar10 helper module at the data path, fetch and load the images."""
        import cifar10

        cifar10.data_path = self.cifar10_path
        logger.info('Load {} dataset...'.format(self.dataset_name))

        # The CIFAR-10 data-set is about 163 MB and will be downloaded
        # automatically if it is not located in the given path.
        cifar10.maybe_download_and_extract()

        # Only the images are kept; the two remaining return values are dropped.
        images, _, _ = cifar10.load_training_data()
        self.train_data = images
        self.num_trains = images.shape[0]

        for message in ('Load {} dataset SUCCESS!'.format(self.dataset_name),
                        'Img size: {}'.format(self.image_size),
                        'Num. of training data: {}'.format(self.num_trains)):
            logger.info(message)

    def train_next_batch(self, batch_size):
        """
        Draw a random batch of training images.

        :param batch_size: Number of distinct images to draw.
        :return: The selected images mapped through ``x * 2 - 1``
            (assumes the stored images are in [0, 1], giving [-1, 1]).
        """
        picks = np.random.choice(self.num_trains, batch_size, replace=False)
        selected = self.train_data[picks]
        return selected * 2. - 1.
# noinspection PyPep8Naming
def Dataset(sess, flags, dataset_name, log_path=None):
    """
    Factory returning the dataset object that matches dataset_name.

    :param sess:
        Session handed to MnistDataset; not used by the other datasets.
    :param flags:
        Settings object. When ``flags.is_train`` is true the module logger
        is initialised through ``_init_logger`` before the dataset is built.
    :param dataset_name:
        One of 'mnist', 'cifar10' or 'imagenet64'.
    :param log_path:
        Directory passed to ``_init_logger``, which joins it with the log
        file name, so it has to be a real path when ``flags.is_train`` is true.
    :return:
        A MnistDataset, Cifar10 or ImageNet64 instance.
    :raises NotImplementedError:
        If dataset_name is not one of the supported names.
    """
    if flags.is_train:
        _init_logger(flags, log_path)  # init logger

    if dataset_name == 'mnist':
        return MnistDataset(sess, flags, dataset_name)
    if dataset_name == 'cifar10':
        return Cifar10(flags, dataset_name)
    if dataset_name == 'imagenet64':
        return ImageNet64(flags, dataset_name)

    # Name the rejected value so a typo in the configuration is easy to spot.
    raise NotImplementedError(
        "Unsupported dataset_name {!r}; expected 'mnist', 'cifar10' or 'imagenet64'".format(
            dataset_name))
def maybe_download_and_extract(url, download_dir):
    """
    Download and extract the data if it doesn't already exist.
    Assumes the url is a tar-ball or zip file.

    :param url:
        Internet URL for the archive to download.
        Example: "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
    :param download_dir:
        Directory where the downloaded file is saved and unpacked.
        Example: "data/CIFAR-10/"
    :return:
        Nothing.
    """

    # Filename for saving the file downloaded from the internet.
    # Use the filename from the URL and add it to the download_dir.
    filename = url.split('/')[-1]
    target_path = os.path.join(download_dir, filename)

    # If the file exists we assume it has also been extracted.
    if os.path.exists(target_path):
        print("Data has apparently already been downloaded and unpacked.")
        return

    # Make sure the download directory exists.
    os.makedirs(download_dir, exist_ok=True)

    # Download the file from the internet.
    try:
        file_path, _ = urllib.request.urlretrieve(url=url,
                                                  filename=target_path,
                                                  reporthook=_print_download_progress)
    except BaseException:
        # An interrupted transfer would otherwise leave a truncated file
        # behind, and the existence check above would then treat it as a
        # finished download on every later run. Remove it and re-raise.
        if os.path.exists(target_path):
            os.remove(target_path)
        raise

    print()
    print("Download finished. Extracting files.")

    # NOTE(review): extractall trusts the member paths stored in the
    # archive; only use this with archives from a trusted source.
    if file_path.endswith(".zip"):
        # Unpack the zip-file; the handle is closed when the block exits.
        with zipfile.ZipFile(file=file_path, mode="r") as archive:
            archive.extractall(download_dir)
    elif file_path.endswith((".tar.gz", ".tgz")):
        # Unpack the tar-ball; the handle is closed when the block exits.
        with tarfile.open(name=file_path, mode="r:gz") as archive:
            archive.extractall(download_dir)

    print("Done.")
def save_response_content(response, destination, chunk_size=32*1024):
    """
    Stream the body of an HTTP response into a file, showing a progress bar.

    :param response:
        Response object providing ``headers.get`` and ``iter_content``
        (created with ``stream=True`` by the caller in this file).
    :param destination:
        Path of the file to write; also used as the progress-bar label.
    :param chunk_size:
        Number of bytes requested per chunk from ``iter_content``.
    :return:
        Nothing.
    """
    total_size = int(response.headers.get('content-length', 0))

    # The bar is measured in bytes, so it has to be advanced by the size of
    # each chunk. Wrapping the chunk iterator in tqdm would advance it by
    # one per chunk and never get near the byte total. An unknown length
    # (0) is passed as None so tqdm shows a plain running counter.
    with open(destination, "wb") as f, \
            tqdm(total=total_size or None, unit='B', unit_scale=True,
                 desc=destination) as progress:
        for chunk in response.iter_content(chunk_size):
            if chunk:  # filter out keep-alive new chunks
                f.write(chunk)
                progress.update(len(chunk))
def _download_lsun(out_dir, category, set_name, tag):
    """
    Download one LSUN archive with curl.

    :param out_dir: Directory the zip-file is written to.
    :param category: LSUN category, e.g. 'bedroom'; unused in the file name
        when set_name is 'test'.
    :param set_name: 'train', 'val' or 'test'.
    :param tag: Release tag sent to the download service, e.g. 'latest'.
    :return: Nothing.
    """
    # Name the substituted values explicitly instead of formatting with
    # **locals(), which silently depends on every local variable name.
    url = ('http://lsun.cs.princeton.edu/htbin/download.cgi?tag={tag}'
           '&category={category}&set={set_name}').format(
               tag=tag, category=category, set_name=set_name)
    print(url)

    # The test-set is shared by all categories and has a fixed file name.
    if set_name == 'test':
        out_name = 'test_lmdb.zip'
    else:
        out_name = '{category}_{set_name}_lmdb.zip'.format(
            category=category, set_name=set_name)

    out_path = os.path.join(out_dir, out_name)
    cmd = ['curl', url, '-o', out_path]
    print('Downloading', category, set_name, 'set')

    # Still best-effort, but a failed transfer is no longer silent.
    status = subprocess.call(cmd)
    if status != 0:
        print('curl exited with status', status, 'while downloading', url)
def prepare_data_dir(path='./data'):
    """
    Make sure the data directory exists.

    :param path: Directory to create; missing parent directories are
        created as well. Nothing is done if the path already exists.
    :return: Nothing.
    """
    if not os.path.exists(path):
        # makedirs also handles nested paths; exist_ok covers the case
        # where the directory appears between the check and the creation.
        os.makedirs(path, exist_ok=True)
def get_inception_score(images, flags, splits=10):
    """
    Compute the Inception Score of a list of images.

    :param images:
        Python list of 3-dimensional numpy arrays. The first image is
        checked to have a maximum above 10 and a minimum of at least 0,
        i.e. values are expected on a 0..255 scale.
    :param flags:
        Settings object; ``flags.gpu_index`` is written to the
        CUDA_VISIBLE_DEVICES environment variable.
    :param splits:
        Number of consecutive groups the predictions are divided into.
    :return:
        Tuple (mean, standard deviation) of the per-group scores.
    """
    os.environ['CUDA_VISIBLE_DEVICES'] = flags.gpu_index

    # Sanity checks on the container and on the first image only.
    assert type(images) == list
    assert type(images[0]) == np.ndarray
    assert len(images[0].shape) == 3
    assert np.max(images[0]) > 10
    assert np.min(images[0]) >= 0.0

    # One float32 array with a leading batch axis per image.
    batched = [np.expand_dims(image.astype(np.float32), 0) for image in images]

    bs = 1  # original is 100
    with tf.Session() as sess:
        outputs = []
        for start in range(0, len(batched), bs):
            feed = np.concatenate(batched[start:start + bs], 0)
            outputs.append(sess.run(softmax, {'ExpandDims:0': feed}))

    preds = np.concatenate(outputs, 0)
    total = preds.shape[0]

    scores = []
    for k in range(splits):
        lower = k * total // splits
        upper = (k + 1) * total // splits
        part = preds[lower:upper, :]

        # KL divergence of every prediction from the group's mean
        # prediction, averaged over the group and exponentiated.
        marginal = np.mean(part, axis=0, keepdims=True)
        divergence = np.sum(part * (np.log(part) - np.log(marginal)), axis=1)
        scores.append(np.exp(np.mean(divergence)))

    return np.mean(scores), np.std(scores)
def _init_inception():
    """
    Download the Inception graph if needed, import it into the default
    graph and build the module-level ``softmax`` tensor.

    The archive named by DATA_URL is stored and unpacked in MODEL_DIR.
    The result is published through the global ``softmax``; nothing is
    returned.
    """
    global softmax
    if not os.path.exists(MODEL_DIR):
        os.makedirs(MODEL_DIR)

    filename = DATA_URL.split('/')[-1]
    filepath = os.path.join(MODEL_DIR, filename)

    if not os.path.exists(filepath):
        def _progress(count, block_size, total_size):
            # \r rewrites the same console line on every callback.
            sys.stdout.write('\r>> Downloading %s %.1f%%' % (
                filename, float(count * block_size) / float(total_size) * 100.0))
            sys.stdout.flush()

        filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress)
        print()
        statinfo = os.stat(filepath)
        print('Succesfully downloaded', filename, statinfo.st_size, 'bytes.')

    # Unpack on every call; the context manager closes the archive handle.
    with tarfile.open(filepath, 'r:gz') as archive:
        archive.extractall(MODEL_DIR)

    with tf.gfile.FastGFile(os.path.join(MODEL_DIR, 'classify_image_graph_def.pb'), 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        _ = tf.import_graph_def(graph_def, name='')

    # Works with an arbitrary minibatch size.
    with tf.Session() as sess:
        pool3 = sess.graph.get_tensor_by_name('pool_3:0')
        ops = pool3.graph.get_operations()

        for op in ops:
            for o in op.outputs:
                shape = o.get_shape()
                shape = [s.value for s in shape]

                # Replace a leading dimension of 1 by None (unknown).
                new_shape = [None if (j == 0 and s == 1) else s
                             for j, s in enumerate(shape)]

                # NOTE(review): this assigns to the attribute `set_shape`
                # rather than calling it, so it does not go through the
                # tensor's shape-setting method. Left unchanged; confirm
                # the intended way to relax the batch dimension before
                # feeding batches larger than 1.
                o.set_shape = tf.TensorShape(new_shape)

        w = sess.graph.get_operation_by_name("softmax/logits/MatMul").inputs[1]
        logits = tf.matmul(tf.squeeze(pool3, [1, 2]), w)
        softmax = tf.nn.softmax(logits)
"\tGEN_DIM: 2\n", 41 | "\tITERS: 100000\n", 42 | "\tLAMBDA: 0.1\n" 43 | ] 44 | } 45 | ], 46 | "source": [ 47 | "DIM = 512 # model dimensionality\n", 48 | "GEN_DIM = 2 # output dimension of the generator\n", 49 | "DIS_DIM = 1 # outptu dimension fo the discriminator\n", 50 | "FIXED_GENERATOR = False # wheter to hold the generator fixed at ral data plus Gaussian noise, as in the plots in the paper\n", 51 | "LAMBDA = .1 # smaller lambda makes things faster for toy tasks, but isn't necessary if you increase CRITIC_ITERS enough\n", 52 | "BATCH_SIZE = 256 # batch size\n", 53 | "ITERS = 100000 # how many generator iterations to train for\n", 54 | "FREQ = 1000 # sample frequency\n", 55 | "\n", 56 | "mode = 'wgan-g' # [gan, wgan, wgan-gp]\n", 57 | "dataset = 'swissroll' # [8gaussians, 25gaussians, swissroll]\n", 58 | "img_folder = os.path.join('img', mode + '_' + dataset + '_' + str(FIXED_GENERATOR))\n", 59 | "\n", 60 | "if mode == 'gan':\n", 61 | " CRITIC_ITERS = 1 # homw many critic iteractions per generator iteration\n", 62 | "else:\n", 63 | " CRITIC_ITERS = 5 # homw many critic iteractions per generator iteration\n", 64 | "\n", 65 | "if not os.path.isdir(img_folder):\n", 66 | " os.makedirs(img_folder)\n", 67 | "\n", 68 | "utils.print_model_setting(locals().copy())" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": 3, 74 | "metadata": {}, 75 | "outputs": [ 76 | { 77 | "name": "stdout", 78 | "output_type": "stream", 79 | "text": [ 80 | "gen/fc-1/add [256, 512]\n", 81 | "gen/fc-2/add [256, 512]\n", 82 | "gen/fc-3/add [256, 512]\n", 83 | "gen/fc-4/add [256, 2]\n", 84 | "is_reuse: False\n", 85 | "disc/fc-1/add [None, 512]\n", 86 | "disc/fc-2/add [None, 512]\n", 87 | "disc/fc-3/add [None, 512]\n", 88 | "disc/fc-4/add [None, 1]\n", 89 | "is_reuse: True\n", 90 | "disc_1/fc-1/add [256, 512]\n", 91 | "disc_1/fc-2/add [256, 512]\n", 92 | "disc_1/fc-3/add [256, 512]\n", 93 | "disc_1/fc-4/add [256, 1]\n" 94 | ] 95 | } 96 | ], 97 | "source": [ 98 | "def 
Generator(n_samples, real_data_, name='gen'):\n", 99 | " if FIXED_GENERATOR:\n", 100 | " return real_data_ + (1. * tf.random_normal(tf.shape(real_data_)))\n", 101 | " else:\n", 102 | " with tf.variable_scope(name):\n", 103 | " noise = tf.random_normal([n_samples, 2])\n", 104 | " output01 = tf_utils.linear(noise, DIM, name='fc-1')\n", 105 | " output01 = tf_utils.relu(output01, name='relu-1')\n", 106 | " \n", 107 | " output02 = tf_utils.linear(output01, DIM, name='fc-2')\n", 108 | " output02 = tf_utils.relu(output02, name='relu-2')\n", 109 | " \n", 110 | " output03 = tf_utils.linear(output02, DIM, name='fc-3')\n", 111 | " output03 = tf_utils.relu(output03, name='relu-3')\n", 112 | " \n", 113 | " output04 = tf_utils.linear(output03, GEN_DIM, name='fc-4')\n", 114 | " \n", 115 | " return output04\n", 116 | " \n", 117 | "\n", 118 | "def Discriminator(inputs, is_reuse=True, name='disc'):\n", 119 | " with tf.variable_scope(name, reuse=is_reuse):\n", 120 | " print('is_reuse: {}'.format(is_reuse))\n", 121 | " output01 = tf_utils.linear(inputs, DIM, name='fc-1')\n", 122 | " output01 = tf_utils.relu(output01, name='relu-1')\n", 123 | "\n", 124 | " output02 = tf_utils.linear(output01, DIM, name='fc-2')\n", 125 | " output02 = tf_utils.relu(output02, name='relu-2')\n", 126 | "\n", 127 | " output03 = tf_utils.linear(output02, DIM, name='fc-3')\n", 128 | " output03 = tf_utils.relu(output03, name='relu-3')\n", 129 | "\n", 130 | " output04 = tf_utils.linear(output03, DIS_DIM, name='fc-4')\n", 131 | " \n", 132 | " return output04\n", 133 | " \n", 134 | "real_data = tf.placeholder(tf.float32, shape=[None, 2])\n", 135 | "fake_data = Generator(BATCH_SIZE, real_data)\n", 136 | "\n", 137 | "disc_real = Discriminator(real_data, is_reuse=False)\n", 138 | "disc_fake = Discriminator(fake_data)\n", 139 | "\n", 140 | "if mode == 'wgan' or mode == 'wgan-gp':\n", 141 | " # WGAN loss\n", 142 | " disc_cost = tf.reduce_mean(disc_fake) - tf.reduce_mean(disc_real)\n", 143 | " gen_cost = - 
tf.reduce_mean(disc_fake)\n", 144 | "\n", 145 | " # WGAN gradient penalty\n", 146 | " if mode == 'wgan-gp':\n", 147 | " alpha = tf.random_uniform(shape=[BATCH_SIZE, 1], minval=0., maxval=1.)\n", 148 | " interpolates = alpha*real_data + (1.-alpha) * fake_data\n", 149 | " disc_interpolates = Discriminator(interpolates)\n", 150 | " gradients = tf.gradients(disc_interpolates, [interpolates][0])\n", 151 | " slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), reduction_indices=[1]))\n", 152 | " gradient_penalty = tf.reduce_mean((slopes - 1)**2)\n", 153 | " \n", 154 | " disc_cost += LAMBDA * gradient_penalty\n", 155 | "elif mode == 'gan':\n", 156 | " gen_cost = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=disc_fake, labels=tf.ones_like(disc_fake)))\n", 157 | " \n", 158 | " disc_cost = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=disc_fake, labels=tf.zeros_like(disc_fake)))\n", 159 | " disc_cost += tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=disc_real, labels=tf.ones_like(disc_real)))\n", 160 | " disc_cost /= 2.\n", 161 | " \n", 162 | "disc_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='disc')\n", 163 | "gen_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='gen')\n", 164 | "\n", 165 | "if mode == 'wgan-gp':\n", 166 | " disc_train_op = tf.train.AdamOptimizer(learning_rate=1e-4, beta1=0.5, beta2=0.9).minimize(disc_cost, var_list=disc_vars)\n", 167 | " \n", 168 | " if len(gen_vars) > 0:\n", 169 | " gen_train_op = tf.train.AdamOptimizer(learning_rate=1e-4, beta1=0.5, beta2=0.9).minimize(gen_cost, var_list=gen_vars)\n", 170 | " else:\n", 171 | " gen_train_op = tf.no_op()\n", 172 | " \n", 173 | "elif mode == 'wgan':\n", 174 | " disc_train_op = tf.train.RMSPropOptimizer(learning_rate=5e-5).minimize(disc_cost, var_list=disc_vars)\n", 175 | " \n", 176 | " if len(gen_vars) > 0:\n", 177 | " gen_train_op = tf.train.RMSPropOptimizer(learning_rate=5e-5).minimize(gen_cost, var_list=gen_vars)\n", 178 | 
" else:\n", 179 | " gen_train_op = tf.no_op()\n", 180 | " \n", 181 | " # build an op to do the weight clipping\n", 182 | " clip_bounds = [-0.01, 0.01]\n", 183 | " clip_ops = [var.assign(tf.clip_by_value(var, clip_bounds[0], clip_bounds[1])) for var in disc_vars]\n", 184 | " clip_disc_weights = tf.group(*clip_ops)\n", 185 | " \n", 186 | "elif mode == 'gan':\n", 187 | " disc_train_op = tf.train.AdamOptimizer(learning_rate=2e-4, beta1=0.5).minimize(disc_cost, var_list=disc_vars)\n", 188 | " \n", 189 | " if len(gen_vars) > 0: \n", 190 | " gen_train_op = tf.train.AdamOptimizer(learning_rate=2e-4, beta1=0.5).minimize(gen_cost, var_list=gen_vars)\n", 191 | " else:\n", 192 | " gen_train_op = tf.no_op()" 193 | ] 194 | }, 195 | { 196 | "cell_type": "code", 197 | "execution_count": 4, 198 | "metadata": { 199 | "collapsed": true 200 | }, 201 | "outputs": [], 202 | "source": [ 203 | "def generate_image(sess, true_dist, idx):\n", 204 | " # generates and saves a plot of the true distribution, the generator, and the critic\n", 205 | " N_POINTS = 128\n", 206 | " RANGE = 2\n", 207 | " \n", 208 | " points = np.zeros((N_POINTS, N_POINTS, 2), dtype='float32')\n", 209 | " points[:, :, 0] = np.linspace(-RANGE, RANGE, N_POINTS)[:, None]\n", 210 | " points[:, :, 1] = np.linspace(-RANGE, RANGE, N_POINTS)[None, :]\n", 211 | " points = points.reshape((-1, 2))\n", 212 | " \n", 213 | " if FIXED_GENERATOR is not True:\n", 214 | " samples = sess.run(fake_data, feed_dict={real_data: points})\n", 215 | " disc_map = sess.run(disc_real, feed_dict={real_data: points})\n", 216 | " \n", 217 | " plt.clf()\n", 218 | " x = y = np.linspace(-RANGE, RANGE, N_POINTS)\n", 219 | " plt.contour(x, y, disc_map.reshape((len(x), len(y))).transpose())\n", 220 | " plt.colorbar() # add color bar\n", 221 | " \n", 222 | " plt.scatter(true_dist[:, 0], true_dist[:, 1], c='orange', marker='+')\n", 223 | " if FIXED_GENERATOR is not True:\n", 224 | " plt.scatter(samples[:, 0], samples[:, 1], c='green', marker='*')\n", 225 | " 
\n", 226 | " plt.savefig(os.path.join(img_folder, str(idx).zfill(3) + '.jpg'))" 227 | ] 228 | }, 229 | { 230 | "cell_type": "code", 231 | "execution_count": 5, 232 | "metadata": { 233 | "collapsed": true 234 | }, 235 | "outputs": [], 236 | "source": [ 237 | "# Dataset iterator\n", 238 | "def inf_train_gen():\n", 239 | " if dataset == '8gaussians':\n", 240 | " scale = 2.\n", 241 | " centers = [(1.,0.), \n", 242 | " (-1.,0.), \n", 243 | " (0., 1.), \n", 244 | " (0.,-1.),\n", 245 | " (1./np.sqrt(2), 1./np.sqrt(2)),\n", 246 | " (1./np.sqrt(2), -1/np.sqrt(2)), \n", 247 | " (-1./np.sqrt(2), 1./np.sqrt(2)), \n", 248 | " (-1./np.sqrt(2), -1./np.sqrt(2))]\n", 249 | " \n", 250 | " centers = [(scale*x, scale*y) for x, y in centers]\n", 251 | " while True:\n", 252 | " batch_data = []\n", 253 | " for _ in range(BATCH_SIZE):\n", 254 | " point = np.random.randn(2) * .02\n", 255 | " center = random.choice(centers)\n", 256 | " point[0] += center[0]\n", 257 | " point[1] += center[1]\n", 258 | " batch_data.append(point)\n", 259 | " \n", 260 | " batch_data = np.array(batch_data, dtype=np.float32)\n", 261 | " batch_data /= 1.414 # std\n", 262 | " yield batch_data\n", 263 | " \n", 264 | " elif dataset == '25gaussians':\n", 265 | " batch_data = []\n", 266 | " for i_ in range(4000):\n", 267 | " for x in range(-2, 3):\n", 268 | " for y in range(-2, 3):\n", 269 | " point = np.random.randn(2) * 0.05\n", 270 | " point[0] += 2*x\n", 271 | " point[1] += 2*y\n", 272 | " batch_data.append(point)\n", 273 | " \n", 274 | " batch_data = np.asarray(batch_data, dtype=np.float32)\n", 275 | " np.random.shuffle(batch_data)\n", 276 | " batch_data /= 2.828 # std\n", 277 | " \n", 278 | " while True:\n", 279 | " for i_ in range(int(len(batch_data)/BATCH_SIZE)):\n", 280 | " yield batch_data[i_*BATCH_SIZE:(i_+1)*BATCH_SIZE]\n", 281 | " \n", 282 | " elif dataset == 'swissroll':\n", 283 | " while True:\n", 284 | " batch_data = sklearn.datasets.make_swiss_roll(n_samples=BATCH_SIZE, noise=0.25)[0]\n", 285 | " 
batch_data = batch_data.astype(np.float32)[:, [0, 2]]\n", 286 | " batch_data /= 7.5 # stdev plus a little\n", 287 | " yield batch_data" 288 | ] 289 | }, 290 | { 291 | "cell_type": "code", 292 | "execution_count": 6, 293 | "metadata": {}, 294 | "outputs": [ 295 | { 296 | "name": "stdout", 297 | "output_type": "stream", 298 | "text": [ 299 | "iter 0\tdisc cost\t0.7139465808868408\n", 300 | "iter 1000\tdisc cost\t0.6814616771042347\n", 301 | "iter 2000\tdisc cost\t0.6737550102472305\n", 302 | "iter 3000\tdisc cost\t0.6920662943720818\n", 303 | "iter 4000\tdisc cost\t0.6906127047538757\n", 304 | "iter 5000\tdisc cost\t0.6930880607366562\n", 305 | "iter 6000\tdisc cost\t0.6931662155985833\n", 306 | "iter 7000\tdisc cost\t0.6919992100596428\n", 307 | "iter 8000\tdisc cost\t0.6933745203018189\n", 308 | "iter 9000\tdisc cost\t0.6926324440836906\n", 309 | "iter 10000\tdisc cost\t0.6933235999941826\n", 310 | "iter 11000\tdisc cost\t0.6933176911473274\n", 311 | "iter 12000\tdisc cost\t0.6923655412197113\n", 312 | "iter 13000\tdisc cost\t0.6932780802249908\n", 313 | "iter 14000\tdisc cost\t0.6932716737985611\n", 314 | "iter 15000\tdisc cost\t0.6932756861448288\n", 315 | "iter 16000\tdisc cost\t0.6915568546652794\n", 316 | "iter 17000\tdisc cost\t0.6933772000074386\n", 317 | "iter 18000\tdisc cost\t0.6933294381499291\n", 318 | "iter 19000\tdisc cost\t0.6924007230997086\n", 319 | "iter 20000\tdisc cost\t0.6928844445347786\n", 320 | "iter 21000\tdisc cost\t0.6933696148395538\n", 321 | "iter 22000\tdisc cost\t0.6933540980815888\n", 322 | "iter 23000\tdisc cost\t0.6933285065889359\n", 323 | "iter 24000\tdisc cost\t0.6918070427775384\n", 324 | "iter 25000\tdisc cost\t0.6933358488082886\n", 325 | "iter 26000\tdisc cost\t0.6933039631843567\n", 326 | "iter 27000\tdisc cost\t0.6932847768068313\n", 327 | "iter 28000\tdisc cost\t0.6932973915338516\n", 328 | "iter 29000\tdisc cost\t0.6932247024178505\n", 329 | "iter 30000\tdisc cost\t0.6932735745310783\n", 330 | "iter 31000\tdisc 
cost\t0.6932096065282821\n", 331 | "iter 32000\tdisc cost\t0.6890769819617272\n", 332 | "iter 33000\tdisc cost\t0.6923709976673126\n", 333 | "iter 34000\tdisc cost\t0.6931670038104057\n", 334 | "iter 35000\tdisc cost\t0.6929483571052552\n", 335 | "iter 36000\tdisc cost\t0.6932186397910118\n", 336 | "iter 37000\tdisc cost\t0.6932798113822937\n", 337 | "iter 38000\tdisc cost\t0.6932742780447007\n", 338 | "iter 39000\tdisc cost\t0.6931445201039315\n", 339 | "iter 40000\tdisc cost\t0.6932249497771263\n", 340 | "iter 41000\tdisc cost\t0.6932100749611855\n", 341 | "iter 42000\tdisc cost\t0.6929260883331299\n", 342 | "iter 43000\tdisc cost\t0.6932117487192154\n", 343 | "iter 44000\tdisc cost\t0.6931842603087425\n", 344 | "iter 45000\tdisc cost\t0.6931332590579986\n", 345 | "iter 46000\tdisc cost\t0.6931791034936905\n", 346 | "iter 47000\tdisc cost\t0.693181702375412\n", 347 | "iter 48000\tdisc cost\t0.6931661418676376\n", 348 | "iter 49000\tdisc cost\t0.6932356751561165\n", 349 | "iter 50000\tdisc cost\t0.6931068043708801\n", 350 | "iter 51000\tdisc cost\t0.6931518998742103\n", 351 | "iter 52000\tdisc cost\t0.6929512389302254\n", 352 | "iter 53000\tdisc cost\t0.69314514118433\n", 353 | "iter 54000\tdisc cost\t0.6931488544344903\n", 354 | "iter 55000\tdisc cost\t0.6926609833240509\n", 355 | "iter 56000\tdisc cost\t0.6931225464940071\n", 356 | "iter 57000\tdisc cost\t0.693101442694664\n", 357 | "iter 58000\tdisc cost\t0.6929626023769379\n", 358 | "iter 59000\tdisc cost\t0.6930847532749176\n", 359 | "iter 60000\tdisc cost\t0.6928317602872849\n", 360 | "iter 61000\tdisc cost\t0.6930897635817528\n", 361 | "iter 62000\tdisc cost\t0.6931442449092865\n", 362 | "iter 63000\tdisc cost\t0.6931378693580628\n", 363 | "iter 64000\tdisc cost\t0.6931423845887185\n", 364 | "iter 65000\tdisc cost\t0.6929416499137878\n", 365 | "iter 66000\tdisc cost\t0.6930960271954536\n", 366 | "iter 67000\tdisc cost\t0.6931367915272713\n", 367 | "iter 68000\tdisc cost\t0.6931659261584282\n", 368 | "iter 
69000\tdisc cost\t0.6931650475263595\n", 369 | "iter 70000\tdisc cost\t0.6930810611844063\n", 370 | "iter 71000\tdisc cost\t0.6931314595341682\n", 371 | "iter 72000\tdisc cost\t0.6931117633581162\n", 372 | "iter 73000\tdisc cost\t0.6931591765880585\n", 373 | "iter 74000\tdisc cost\t0.6929754077196121\n", 374 | "iter 75000\tdisc cost\t0.693103608250618\n", 375 | "iter 76000\tdisc cost\t0.6931282907724381\n", 376 | "iter 77000\tdisc cost\t0.6930146722197532\n", 377 | "iter 78000\tdisc cost\t0.6930771964788437\n", 378 | "iter 79000\tdisc cost\t0.6931223605871201\n", 379 | "iter 80000\tdisc cost\t0.693151865184307\n", 380 | "iter 81000\tdisc cost\t0.693146475315094\n", 381 | "iter 82000\tdisc cost\t0.6931274979114532\n", 382 | "iter 83000\tdisc cost\t0.6924246903657914\n", 383 | "iter 84000\tdisc cost\t0.6930215736031532\n", 384 | "iter 85000\tdisc cost\t0.6930515897274018\n", 385 | "iter 86000\tdisc cost\t0.6930611334443092\n", 386 | "iter 87000\tdisc cost\t0.6926919516324997\n", 387 | "iter 88000\tdisc cost\t0.6930547666549682\n", 388 | "iter 89000\tdisc cost\t0.6929932239055634\n", 389 | "iter 90000\tdisc cost\t0.6930882499814034\n", 390 | "iter 91000\tdisc cost\t0.6930873312354088\n", 391 | "iter 92000\tdisc cost\t0.6923705233931542\n", 392 | "iter 93000\tdisc cost\t0.6930174446702003\n", 393 | "iter 94000\tdisc cost\t0.6927620051503182\n", 394 | "iter 95000\tdisc cost\t0.69283284419775\n", 395 | "iter 96000\tdisc cost\t0.6930583512187004\n", 396 | "iter 97000\tdisc cost\t0.6930861786007881\n", 397 | "iter 98000\tdisc cost\t0.6927603552937508\n", 398 | "iter 99000\tdisc cost\t0.6929180439114571\n", 399 | "iter 99999\tdisc cost\t0.6930565155662216\n" 400 | ] 401 | } 402 | ], 403 | "source": [ 404 | "# Train loop\n", 405 | "with tf.Session() as sess:\n", 406 | " sess.run(tf.global_variables_initializer())\n", 407 | " data_gen = inf_train_gen()\n", 408 | " \n", 409 | " for iter_ in range(ITERS):\n", 410 | " batch_data, disc_cost_ = None, None\n", 411 | " \n", 412 | " 
# train critic\n", 413 | " for i_ in range(CRITIC_ITERS):\n", 414 | " batch_data = data_gen.__next__()\n", 415 | " disc_cost_, _ = sess.run([disc_cost, disc_train_op], feed_dict={real_data: batch_data})\n", 416 | " \n", 417 | " if mode == 'wgan':\n", 418 | " sess.run(clip_disc_weights)\n", 419 | " \n", 420 | " # train generator\n", 421 | " sess.run(gen_train_op)\n", 422 | " \n", 423 | " # write logs and save samples\n", 424 | " utils.plot('disc cost', disc_cost_)\n", 425 | " \n", 426 | " if (np.mod(iter_, FREQ) == 0) or (iter_+1 == ITERS):\n", 427 | " utils.flush(img_folder)\n", 428 | " generate_image(sess, batch_data, iter_)\n", 429 | " \n", 430 | " utils.tick() " 431 | ] 432 | } 433 | ], 434 | "metadata": { 435 | "kernelspec": { 436 | "display_name": "Python 3", 437 | "language": "python", 438 | "name": "python3" 439 | }, 440 | "language_info": { 441 | "codemirror_mode": { 442 | "name": "ipython", 443 | "version": 3 444 | }, 445 | "file_extension": ".py", 446 | "mimetype": "text/x-python", 447 | "name": "python", 448 | "nbconvert_exporter": "python", 449 | "pygments_lexer": "ipython3", 450 | "version": "3.5.5" 451 | } 452 | }, 453 | "nbformat": 4, 454 | "nbformat_minor": 2 455 | } 456 | -------------------------------------------------------------------------------- /src/jupyter/record_video.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import numpy as np 4 | 5 | 6 | def all_files_under(path, extension=None, special=None, append_path=True, sort=True): 7 | if append_path: 8 | if extension is None: 9 | filenames = [os.path.join(path, fname) for fname in os.listdir(path) if special not in fname] 10 | else: 11 | filenames = [os.path.join(path, fname) for fname in os.listdir(path) 12 | if (special not in fname) and (fname.endswith(extension))] 13 | else: 14 | if extension is None: 15 | filenames = [os.path.basename(fname) for fname in os.listdir(path) if special not in fname] 16 | else: 17 | 
filenames = [os.path.basename(fname) for fname in os.listdir(path) 18 | if (special not in fname) and (fname.endswith(extension))] 19 | 20 | if sort: 21 | filenames = sorted(filenames) 22 | 23 | return filenames 24 | 25 | 26 | def main(path_list, name_list): 27 | for idx_list, path in enumerate(path_list): 28 | gan_8gaussian_paths = all_files_under(os.path.join('img', path[0]), extension='jpg', special='disc') 29 | gan_25gaussian_paths = all_files_under(os.path.join('img', path[1]), extension='jpg', special='disc') 30 | gan_swissroll_paths = all_files_under(os.path.join('img', path[2]), extension='jpg', special='disc') 31 | 32 | wgan_8gaussian_paths = all_files_under(os.path.join('img', path[3]), extension='jpg', special='disc') 33 | wgan_25gaussian_paths = all_files_under(os.path.join('img', path[4]), extension='jpg', special='disc') 34 | wgan_swissroll_paths = all_files_under(os.path.join('img', path[5]), extension='jpg', special='disc') 35 | 36 | wgan_gp_8gaussian_paths = all_files_under(os.path.join('img', path[6]), extension='jpg', special='disc') 37 | wgan_gp_25gaussian_paths = all_files_under(os.path.join('img', path[7]), extension='jpg', special='disc') 38 | wgan_gp_swissroll_paths = all_files_under(os.path.join('img', path[8]), extension='jpg', special='disc') 39 | 40 | frame_shape = cv2.imread(wgan_8gaussian_paths[0]).shape 41 | print(frame_shape) 42 | 43 | # Define the codec and create VideoWriter object 44 | fourcc = cv2.VideoWriter_fourcc(*'XVID') 45 | video_writer = cv2.VideoWriter(name_list[idx_list], fourcc, 10.0, (frame_shape[1]*3, frame_shape[0]*3)) 46 | 47 | for idx in range(len(wgan_8gaussian_paths)): 48 | img_gan8 = cv2.imread(gan_8gaussian_paths[idx]) 49 | img_gan25 = cv2.imread(gan_25gaussian_paths[idx]) 50 | img_gans = cv2.imread(gan_swissroll_paths[idx]) 51 | 52 | img_wgan8 = cv2.imread(wgan_8gaussian_paths[idx]) 53 | img_wgan25 = cv2.imread(wgan_25gaussian_paths[idx]) 54 | img_wgans = cv2.imread(wgan_swissroll_paths[idx]) 55 | 56 | 
img_wgangp8 = cv2.imread(wgan_gp_8gaussian_paths[idx]) 57 | img_wgangp25 = cv2.imread(wgan_gp_25gaussian_paths[idx]) 58 | img_wgangps = cv2.imread(wgan_gp_swissroll_paths[idx]) 59 | 60 | frame_1 = np.hstack([img_gan8, img_gan25, img_gans]) 61 | frame_2 = np.hstack([img_wgan8, img_wgan25, img_wgans]) 62 | frame_3 = np.hstack([img_wgangp8, img_wgangp25, img_wgangps]) 63 | frame = np.vstack([frame_1, frame_2, frame_3]) 64 | 65 | # write the frame 66 | video_writer.write(frame) 67 | 68 | cv2.imshow('Show', frame) 69 | cv2.waitKey(1) 70 | 71 | # Release everything if job is finished 72 | video_writer.release() 73 | cv2.destroyAllWindows() 74 | 75 | 76 | if __name__ == '__main__': 77 | path01 = ['gan_8gaussians_True', 'gan_25gaussians_True', 'gan_swissroll_True', 78 | 'wgan_8gaussians_True', 'wgan_25gaussians_True', 'wgan_swissroll_True', 79 | 'wgan-gp_8gaussians_True', 'wgan-gp_25gaussians_True', 'wgan-gp_swissroll_True'] 80 | path02 = ['gan_8gaussians_False', 'gan_25gaussians_False', 'gan_swissroll_False', 81 | 'wgan_8gaussians_False', 'wgan_25gaussians_False', 'wgan_swissroll_False', 82 | 'wgan-gp_8gaussians_False', 'wgan-gp_25gaussians_False', 'wgan-gp_swissroll_False'] 83 | file_names = ['generator_fixed_true.mp4', 'generator_fixed_false.mp4'] 84 | 85 | main([path01, path02], file_names) 86 | -------------------------------------------------------------------------------- /src/jupyter/tensorflow_utils.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------- 2 | # Tensorflow Utils Implementation 3 | # Licensed under The MIT License [see LICENSE for details] 4 | # Written by Cheng-Bin Jin 5 | # Email: sbkim0407@gmail.com 6 | # --------------------------------------------------------- 7 | import tensorflow as tf 8 | import tensorflow.contrib.slim as slim 9 | from tensorflow.python.training import moving_averages 10 | 11 | 12 | def padding2d(x, p_h=1, p_w=1, pad_type='REFLECT', 
name='pad2d'): 13 | if pad_type == 'REFLECT': 14 | return tf.pad(x, [[0, 0], [p_h, p_h], [p_w, p_w], [0, 0]], 'REFLECT', name=name) 15 | 16 | 17 | def conv2d(x, output_dim, k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02, padding='SAME', name='conv2d', is_print=True): 18 | with tf.variable_scope(name): 19 | w = tf.get_variable('w', [k_h, k_w, x.get_shape()[-1], output_dim], 20 | initializer=tf.truncated_normal_initializer(stddev=stddev)) 21 | conv = tf.nn.conv2d(x, w, strides=[1, d_h, d_w, 1], padding=padding) 22 | 23 | biases = tf.get_variable('biases', [output_dim], initializer=tf.constant_initializer(0.0)) 24 | # conv = tf.reshape(tf.nn.bias_add(conv, biases), conv.get_shape()) 25 | conv = tf.nn.bias_add(conv, biases) 26 | 27 | if is_print: 28 | print_activations(conv) 29 | 30 | return conv 31 | 32 | 33 | def deconv2d(x, k, k_h=3, k_w=3, d_h=2, d_w=2, stddev=0.02, padding_='SAME', output_size=None, 34 | name='deconv2d', with_w=False, is_print=True): 35 | with tf.variable_scope(name): 36 | input_shape = x.get_shape().as_list() 37 | 38 | # calculate output size 39 | h_output, w_output = None, None 40 | if not output_size: 41 | h_output, w_output = input_shape[1] * 2, input_shape[2] * 2 42 | # output_shape = [input_shape[0], h_output, w_output, k] # error when not define batch_size 43 | output_shape = [tf.shape(x)[0], h_output, w_output, k] 44 | 45 | # conv2d transpose 46 | w = tf.get_variable('w', [k_h, k_w, k, input_shape[3]], 47 | initializer=tf.random_normal_initializer(stddev=stddev)) 48 | deconv = tf.nn.conv2d_transpose(x, w, output_shape=output_shape, strides=[1, d_h, d_w, 1], 49 | padding=padding_) 50 | 51 | biases = tf.get_variable('biases', [output_shape[-1]], initializer=tf.constant_initializer(0.0)) 52 | deconv = tf.nn.bias_add(deconv, biases) 53 | 54 | if is_print: 55 | print_activations(deconv) 56 | 57 | if with_w: 58 | return deconv, w, biases 59 | else: 60 | return deconv 61 | 62 | 63 | def upsampling2d(x, size=(2, 2), name='upsampling2d'): 64 | with 
tf.name_scope(name): 65 | shape = x.get_shape().as_list() 66 | return tf.image.resize_nearest_neighbor(x, size=(size[0] * shape[1], size[1] * shape[2])) 67 | 68 | 69 | def linear(x, output_size, bias_start=0.0, with_w=False, name='fc', is_print=True): 70 | shape = x.get_shape().as_list() 71 | 72 | with tf.variable_scope(name): 73 | matrix = tf.get_variable(name="matrix", shape=[shape[1], output_size], 74 | dtype=tf.float32, initializer=tf.contrib.layers.xavier_initializer()) 75 | bias = tf.get_variable(name="bias", shape=[output_size], 76 | initializer=tf.constant_initializer(bias_start)) 77 | output = tf.matmul(x, matrix) + bias 78 | 79 | if is_print: 80 | print_activations(output) 81 | 82 | if with_w: 83 | return output, matrix, bias 84 | else: 85 | return output 86 | 87 | 88 | def norm(x, name, _type, _ops, is_train=True): 89 | if _type == 'batch': 90 | return batch_norm(x, name=name, _ops=_ops, is_train=is_train) 91 | elif _type == 'instance': 92 | return instance_norm(x, name=name) 93 | else: 94 | raise NotImplementedError 95 | 96 | 97 | def batch_norm(x, name, _ops, is_train=True): 98 | """Batch normalization.""" 99 | with tf.variable_scope(name): 100 | params_shape = [x.get_shape()[-1]] 101 | 102 | beta = tf.get_variable('beta', params_shape, tf.float32, 103 | initializer=tf.constant_initializer(0.0, tf.float32)) 104 | gamma = tf.get_variable('gamma', params_shape, tf.float32, 105 | initializer=tf.constant_initializer(1.0, tf.float32)) 106 | 107 | if is_train is True: 108 | mean, variance = tf.nn.moments(x, [0, 1, 2], name='moments') 109 | 110 | moving_mean = tf.get_variable('moving_mean', params_shape, tf.float32, 111 | initializer=tf.constant_initializer(0.0, tf.float32), 112 | trainable=False) 113 | moving_variance = tf.get_variable('moving_variance', params_shape, tf.float32, 114 | initializer=tf.constant_initializer(1.0, tf.float32), 115 | trainable=False) 116 | 117 | _ops.append(moving_averages.assign_moving_average(moving_mean, mean, 0.9)) 118 | 
_ops.append(moving_averages.assign_moving_average(moving_variance, variance, 0.9)) 119 | else: 120 | mean = tf.get_variable('moving_mean', params_shape, tf.float32, 121 | initializer=tf.constant_initializer(0.0, tf.float32), trainable=False) 122 | variance = tf.get_variable('moving_variance', params_shape, tf.float32, trainable=False) 123 | 124 | # epsilon used to be 1e-5. Maybe 0.001 solves NaN problem in deeper net. 125 | y = tf.nn.batch_normalization(x, mean, variance, beta, gamma, 1e-5) 126 | y.set_shape(x.get_shape()) 127 | 128 | return y 129 | 130 | 131 | def instance_norm(x, name='instance_norm', mean=1.0, stddev=0.02, epsilon=1e-5): 132 | with tf.variable_scope(name): 133 | depth = x.get_shape()[3] 134 | scale = tf.get_variable( 135 | 'scale', [depth], tf.float32, 136 | initializer=tf.random_normal_initializer(mean=mean, stddev=stddev, dtype=tf.float32)) 137 | offset = tf.get_variable('offset', [depth], initializer=tf.constant_initializer(0.0)) 138 | 139 | # calculate mean and variance as instance 140 | mean, variance = tf.nn.moments(x, axes=[1, 2], keep_dims=True) 141 | 142 | # normalization 143 | inv = tf.rsqrt(variance + epsilon) 144 | normalized = (x - mean) * inv 145 | 146 | return scale * normalized + offset 147 | 148 | 149 | def n_res_blocks(x, _ops=None, norm_='instance', is_train=True, num_blocks=6, is_print=False): 150 | output = None 151 | for idx in range(1, num_blocks+1): 152 | output = res_block(x, x.get_shape()[3], _ops=_ops, norm_=norm_, is_train=is_train, 153 | name='res{}'.format(idx)) 154 | x = output 155 | 156 | if is_print: 157 | print_activations(output) 158 | 159 | return output 160 | 161 | 162 | # norm(x, name, _type, _ops, is_train=True) 163 | def res_block(x, k, _ops=None, norm_='instance', is_train=True, pad_type=None, name=None): 164 | with tf.variable_scope(name): 165 | conv1, conv2 = None, None 166 | 167 | # 3x3 Conv-Batch-Relu S1 168 | with tf.variable_scope('layer1'): 169 | if pad_type is None: 170 | conv1 = conv2d(x, k, 
k_h=3, k_w=3, d_h=1, d_w=1, padding='SAME', name='conv') 171 | elif pad_type == 'REFLECT': 172 | padded1 = padding2d(x, p_h=1, p_w=1, pad_type='REFLECT', name='padding') 173 | conv1 = conv2d(padded1, k, k_h=3, k_w=3, d_h=1, d_w=1, padding='VALID', name='conv') 174 | normalized1 = norm(conv1, name='norm', _type=norm_, _ops=_ops, is_train=is_train) 175 | relu1 = tf.nn.relu(normalized1) 176 | 177 | # 3x3 Conv-Batch S1 178 | with tf.variable_scope('layer2'): 179 | if pad_type is None: 180 | conv2 = conv2d(relu1, k, k_h=3, k_w=3, d_h=1, d_w=1, padding='SAME', name='conv') 181 | elif pad_type == 'REFLECT': 182 | padded2 = padding2d(relu1, p_h=1, p_w=1, pad_type='REFLECT', name='padding') 183 | conv2 = conv2d(padded2, k, k_h=3, k_w=3, d_h=1, d_w=1, padding='VALID', name='conv') 184 | normalized2 = norm(conv2, name='norm', _type=norm_, _ops=_ops, is_train=is_train) 185 | 186 | # sum layer1 and layer2 187 | output = x + normalized2 188 | return output 189 | 190 | 191 | def identity(x, name='identity', is_print=False): 192 | output = tf.identity(x, name=name) 193 | if is_print: 194 | print_activations(output) 195 | 196 | return output 197 | 198 | 199 | def max_pool_2x2(x, name='max_pool'): 200 | with tf.name_scope(name): 201 | return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME') 202 | 203 | 204 | def sigmoid(x, name='sigmoid', is_print=False): 205 | output = tf.nn.sigmoid(x, name=name) 206 | if is_print: 207 | print_activations(output) 208 | 209 | return output 210 | 211 | 212 | def tanh(x, name='tanh', is_print=False): 213 | output = tf.nn.tanh(x, name=name) 214 | if is_print: 215 | print_activations(output) 216 | 217 | return output 218 | 219 | 220 | def relu(x, name='relu', is_print=False): 221 | output = tf.nn.relu(x, name=name) 222 | if is_print: 223 | print_activations(output) 224 | 225 | return output 226 | 227 | 228 | def lrelu(x, leak=0.2, name='lrelu', is_print=False): 229 | output = tf.maximum(x, leak*x, name=name) 230 | if is_print: 
231 | print_activations(output) 232 | 233 | return output 234 | 235 | 236 | def xavier_init(in_dim): 237 | # print('in_dim: ', in_dim) 238 | xavier_stddev = 1. / tf.sqrt(in_dim / 2.) 239 | return xavier_stddev 240 | 241 | 242 | def print_activations(t): 243 | print(t.op.name, ' ', t.get_shape().as_list()) 244 | 245 | 246 | def show_all_variables(): 247 | model_vars = tf.trainable_variables() 248 | slim.model_analyzer.analyze_vars(model_vars, print_info=True) 249 | 250 | 251 | def batch_convert2int(images): 252 | # images: 4D float tensor (batch_size, image_size, image_size, depth) 253 | return tf.map_fn(convert2int, images, dtype=tf.uint8) 254 | 255 | 256 | def convert2int(image): 257 | # transform from float tensor ([-1.,1.]) to int image ([0,255]) 258 | return tf.image.convert_image_dtype((image + 1.0) / 2.0, tf.uint8) 259 | -------------------------------------------------------------------------------- /src/jupyter/utils.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------- 2 | # Python Utils Implementation 3 | # Licensed under The MIT License [see LICENSE for details] 4 | # Written by Cheng-Bin Jin 5 | # Email: sbkim0407@gmail.com 6 | # --------------------------------------------------------- 7 | import os 8 | import collections 9 | import matplotlib 10 | matplotlib.use('Agg') 11 | import matplotlib.pyplot as plt 12 | 13 | 14 | def print_model_setting(locals_): 15 | print("Uppercase local vars:") 16 | 17 | all_vars = [(k, v) for (k, v) in locals_.items() if ( 18 | k.isupper() and k != 'T' and k != 'SETTINGS' and k != 'ALL_SETTINGS')] 19 | all_vars = sorted(all_vars, key=lambda x: x[0]) 20 | 21 | for var_name, var_value in all_vars: 22 | print("\t{}: {}".format(var_name, var_value)) 23 | 24 | 25 | _since_beginning = collections.defaultdict(lambda: {}) 26 | _since_last_flush = collections.defaultdict(lambda: {}) 27 | _iter = [0] 28 | 29 | 30 | def tick(): 31 | 
_iter[0] += 1 32 | 33 | 34 | def plot(name, value): 35 | _since_last_flush[name][_iter[0]] = value 36 | 37 | 38 | def flush(save_folder): 39 | prints = [] 40 | 41 | for name, vals in _since_last_flush.items(): 42 | sum_ = 0 43 | keys = vals.keys() 44 | values = vals.values() 45 | num_keys = len(list(keys)) 46 | for val in values: 47 | sum_ += val 48 | 49 | prints.append("{}\t{}".format(name, sum_/num_keys)) 50 | _since_beginning[name].update(vals) 51 | 52 | x_vals = _since_beginning[name].keys() 53 | y_vals = [_since_beginning[name][x] for x in x_vals] 54 | 55 | plt.clf() 56 | plt.plot(x_vals, y_vals) 57 | plt.xlabel('iteration') 58 | plt.ylabel(name) 59 | plt.savefig(os.path.join(save_folder, name.replace(' ', '_')+'.jpg')) 60 | 61 | print("iter {}\t{}".format(_iter[0], "\t".join(prints))) 62 | _since_last_flush.clear() 63 | -------------------------------------------------------------------------------- /src/main.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------- 2 | # Tensorflow WGAN-GP Implementation 3 | # Licensed under The MIT License [see LICENSE for details] 4 | # Written by Cheng-Bin Jin 5 | # Email: sbkim0407@gmail.com 6 | # --------------------------------------------------------- 7 | import os 8 | import tensorflow as tf 9 | from solver import Solver 10 | 11 | FLAGS = tf.flags.FLAGS 12 | 13 | tf.flags.DEFINE_string('gpu_index', '0', 'gpu index if you have multiple gpus, default: 0') 14 | tf.flags.DEFINE_integer('batch_size', 64, 'batch size, default: 64') 15 | tf.flags.DEFINE_string('dataset', 'mnist', 'dataset name from [mnist, cifar10, imagenet64], default: mnist') 16 | 17 | tf.flags.DEFINE_bool('is_train', True, 'training or inference mode, default: True') 18 | tf.flags.DEFINE_float('learning_rate', 1e-4, 'initial learning rate for Adam, default: 0.0002') 19 | tf.flags.DEFINE_integer('num_critic', 5, 'the number of iterations of the critic per generator 
iteration, default: 5') 20 | tf.flags.DEFINE_integer('z_dim', 128, 'dimension of z vector, default: 128') 21 | tf.flags.DEFINE_float('lambda_', 10., 'gradient penalty lambda hyperparameter, default: 10.') 22 | tf.flags.DEFINE_float('beta1', 0.5, 'beta1 momentum term of Adam, default: 0.5') 23 | tf.flags.DEFINE_float('beta2', 0.9, 'beta2 momentum term of Adam, default: 0.9') 24 | 25 | tf.flags.DEFINE_integer('iters', 200000, 'number of iterations, default: 200000') 26 | tf.flags.DEFINE_integer('print_freq', 100, 'print frequency for loss, default: 100') 27 | tf.flags.DEFINE_integer('save_freq', 10000, 'save frequency for model, default: 10000') 28 | tf.flags.DEFINE_integer('sample_freq', 500, 'sample frequency for saving image, default: 500') 29 | tf.flags.DEFINE_integer('inception_freq', 1000, 'calculation frequence of inception score, default: 1000') 30 | tf.flags.DEFINE_integer('sample_batch', 64, 'number of sampling images for check generator quality, default: 64') 31 | tf.flags.DEFINE_string('load_model', None, 'folder of saved model taht you wish to continue training ' 32 | '(e.g. 
20181017-1430), default: None') 33 | 34 | 35 | def main(_): 36 | os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu_index 37 | 38 | solver = Solver(FLAGS) 39 | if FLAGS.is_train: 40 | solver.train() 41 | if not FLAGS.is_train: 42 | solver.test() 43 | 44 | 45 | if __name__ == '__main__': 46 | tf.app.run() 47 | -------------------------------------------------------------------------------- /src/plot.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------- 2 | # Python Plot Function 3 | # Licensed under The MIT License [see LICENSE for details] 4 | # From https://github.com/igul222/improved_wgan_training/blob/master/tflib/plot.py 5 | # Code from igul222 6 | # --------------------------------------------------------- 7 | import os 8 | import matplotlib as mpl 9 | mpl.use('TkAgg') # or whatever other backend that you want to solve Segmentation fault (core dumped) 10 | import matplotlib.pyplot as plt 11 | import collections 12 | 13 | 14 | _since_beginning = collections.defaultdict(lambda: {}) 15 | _since_last_flush = collections.defaultdict(lambda: {}) 16 | _iter = [0] 17 | 18 | 19 | def tick(): 20 | _iter[0] += 1 21 | 22 | 23 | def plot(name, value): 24 | _since_last_flush[name][_iter[0]] = value 25 | 26 | 27 | def flush(save_folder): 28 | prints = [] 29 | for name, vals in _since_last_flush.items(): 30 | sum_ = 0 31 | keys = vals.keys() 32 | values = vals.values() 33 | num_keys = len(list(keys)) 34 | for val in values: 35 | sum_ += val 36 | 37 | prints.append("{}\t{}".format(name, sum_/num_keys)) 38 | _since_beginning[name].update(vals) 39 | 40 | x_vals = _since_beginning[name].keys() 41 | y_vals = [_since_beginning[name][x] for x in x_vals] 42 | 43 | plt.clf() 44 | plt.plot(x_vals, y_vals) 45 | plt.xlabel('iteration') 46 | plt.ylabel(name) 47 | plt.savefig(os.path.join(save_folder, name.replace(' ', '_')+'.jpg')) 48 | 49 | # print("iter {}\t{}".format(_iter[0], "\t".join(prints))) 
50 | _since_last_flush.clear() 51 | -------------------------------------------------------------------------------- /src/solver.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------- 2 | # Tensorflow WGAN-GP Implementation 3 | # Licensed under The MIT License [see LICENSE for details] 4 | # Written by Cheng-Bin Jin 5 | # Email: sbkim0407@gmail.com 6 | # --------------------------------------------------------- 7 | import os 8 | import logging 9 | import numpy as np 10 | import tensorflow as tf 11 | from datetime import datetime 12 | 13 | # noinspection PyPep8Naming 14 | import plot as plot 15 | from dataset_ import Dataset 16 | from wgan_gp import WGAN_GP 17 | from inception_score import get_inception_score 18 | 19 | logger = logging.getLogger(__name__) # logger 20 | logger.setLevel(logging.INFO) 21 | 22 | 23 | class Solver(object): 24 | def __init__(self, flags): 25 | run_config = tf.ConfigProto() 26 | run_config.gpu_options.allow_growth = True 27 | self.sess = tf.Session(config=run_config) 28 | 29 | self.flags = flags 30 | self.iter_time = 0 31 | self.num_examples_IS = 1000 32 | self._make_folders() 33 | self._init_logger() 34 | 35 | self.dataset = Dataset(self.sess, self.flags, self.flags.dataset, log_path=self.log_out_dir) 36 | self.model = WGAN_GP(self.sess, self.flags, self.dataset, log_path=self.log_out_dir) 37 | 38 | self.saver = tf.train.Saver() 39 | self.sess.run([tf.global_variables_initializer()]) 40 | 41 | # tf_utils.show_all_variables() 42 | 43 | def _make_folders(self): 44 | if self.flags.is_train: # train stage 45 | if self.flags.load_model is None: 46 | cur_time = datetime.now().strftime("%Y%m%d-%H%M") 47 | self.model_out_dir = "{}/model/{}".format(self.flags.dataset, cur_time) 48 | if not os.path.isdir(self.model_out_dir): 49 | os.makedirs(self.model_out_dir) 50 | else: 51 | cur_time = self.flags.load_model 52 | self.model_out_dir = 
"{}/model/{}".format(self.flags.dataset, cur_time) 53 | 54 | self.sample_out_dir = "{}/sample/{}".format(self.flags.dataset, cur_time) 55 | if not os.path.isdir(self.sample_out_dir): 56 | os.makedirs(self.sample_out_dir) 57 | 58 | self.log_out_dir = "{}/logs/{}".format(self.flags.dataset, cur_time) 59 | self.train_writer = tf.summary.FileWriter("{}/logs/{}".format(self.flags.dataset, cur_time), 60 | graph_def=self.sess.graph_def) 61 | 62 | elif not self.flags.is_train: # test stage 63 | self.model_out_dir = "{}/model/{}".format(self.flags.dataset, self.flags.load_model) 64 | self.test_out_dir = "{}/test/{}".format(self.flags.dataset, self.flags.load_model) 65 | self.log_out_dir = "{}/logs/{}".format(self.flags.dataset, self.flags.load_model) 66 | 67 | if not os.path.isdir(self.test_out_dir): 68 | os.makedirs(self.test_out_dir) 69 | 70 | def _init_logger(self): 71 | formatter = logging.Formatter('%(asctime)s:%(name)s:%(message)s') 72 | # file handler 73 | file_handler = logging.FileHandler(os.path.join(self.log_out_dir, 'solver.log')) 74 | file_handler.setFormatter(formatter) 75 | file_handler.setLevel(logging.INFO) 76 | # stream handler 77 | stream_handler = logging.StreamHandler() 78 | stream_handler.setFormatter(formatter) 79 | # add handlers 80 | logger.addHandler(file_handler) 81 | logger.addHandler(stream_handler) 82 | 83 | if self.flags.is_train: 84 | logger.info('gpu_index: {}'.format(self.flags.gpu_index)) 85 | logger.info('batch_size: {}'.format(self.flags.batch_size)) 86 | logger.info('dataset: {}'.format(self.flags.dataset)) 87 | 88 | logger.info('is_train: {}'.format(self.flags.is_train)) 89 | logger.info('learning_rate: {}'.format(self.flags.learning_rate)) 90 | logger.info('num_critic: {}'.format(self.flags.num_critic)) 91 | logger.info('z_dim: {}'.format(self.flags.z_dim)) 92 | logger.info('lambda_: {}'.format(self.flags.lambda_)) 93 | logger.info('beta1: {}'.format(self.flags.beta1)) 94 | logger.info('beta2: {}'.format(self.flags.beta2)) 95 | 96 | 
logger.info('iters: {}'.format(self.flags.iters)) 97 | logger.info('print_freq: {}'.format(self.flags.print_freq)) 98 | logger.info('save_freq: {}'.format(self.flags.save_freq)) 99 | logger.info('sample_freq: {}'.format(self.flags.sample_freq)) 100 | logger.info('inception_freq: {}'.format(self.flags.inception_freq)) 101 | logger.info('sample_batch: {}'.format(self.flags.sample_batch)) 102 | logger.info('load_model: {}'.format(self.flags.load_model)) 103 | 104 | def train(self): 105 | # load initialized checkpoint that provided 106 | if self.flags.load_model is not None: 107 | if self.load_model(): 108 | logger.info(' [*] Load SUCCESS!\n') 109 | else: 110 | logger.info(' [!] Load Failed...\n') 111 | 112 | # for iter_time in range(self.flags.iters): 113 | while self.iter_time < self.flags.iters: 114 | # sampling images and save them 115 | self.sample(self.iter_time) 116 | 117 | # train_step 118 | loss, summary = self.model.train_step() 119 | self.model.print_info(loss, self.iter_time) 120 | self.train_writer.add_summary(summary, self.iter_time) 121 | self.train_writer.flush() 122 | 123 | if self.flags.dataset == 'cifar10': 124 | self.get_inception_score(self.iter_time) # calculate inception score 125 | 126 | # save model 127 | self.save_model(self.iter_time) 128 | self.iter_time += 1 129 | 130 | self.save_model(self.flags.iters) 131 | 132 | def test(self): 133 | if self.load_model(): 134 | logger.info(' [*] Load SUCCESS!') 135 | else: 136 | logger.info(' [!] 
Load Failed...') 137 | 138 | num_iters = 20 139 | for iter_time in range(num_iters): 140 | print('iter_time: {}'.format(iter_time)) 141 | 142 | imgs = self.model.test_step() 143 | self.model.plots(imgs, iter_time, self.test_out_dir) 144 | 145 | def get_inception_score(self, iter_time): 146 | if np.mod(iter_time, self.flags.inception_freq) == 0: 147 | sample_size = 100 148 | all_samples = [] 149 | for _ in range(int(self.num_examples_IS/sample_size)): 150 | imgs = self.model.sample_imgs(sample_size=sample_size) 151 | all_samples.append(imgs[0]) 152 | 153 | all_samples = np.concatenate(all_samples, axis=0) 154 | all_samples = ((all_samples + 1.) * 255. / 2.).astype(np.uint8) 155 | 156 | mean_IS, std_IS = get_inception_score(list(all_samples), self.flags) 157 | # print('Inception score iter: {}, IS: {}'.format(self.iter_time, mean_IS)) 158 | 159 | plot.plot('inception score', mean_IS) 160 | plot.flush(self.log_out_dir) # write logs 161 | plot.tick() 162 | 163 | def sample(self, iter_time): 164 | if np.mod(iter_time, self.flags.sample_freq) == 0: 165 | imgs = self.model.sample_imgs(sample_size=self.flags.sample_batch) 166 | self.model.plots(imgs, iter_time, self.sample_out_dir) 167 | 168 | def save_model(self, iter_time): 169 | if np.mod(iter_time + 1, self.flags.save_freq) == 0: 170 | model_name = 'model' 171 | self.saver.save(self.sess, os.path.join(self.model_out_dir, model_name), global_step=iter_time) 172 | logger.info('[*] Model saved! 
Iter: {}'.format(iter_time)) 173 | 174 | def load_model(self): 175 | logger.info(' [*] Reading checkpoint...') 176 | 177 | checkpoint = tf.train.get_checkpoint_state(self.model_out_dir) 178 | if checkpoint and checkpoint.model_checkpoint_path: 179 | ckpt_name = os.path.basename(checkpoint.model_checkpoint_path) 180 | self.saver.restore(self.sess, os.path.join(self.model_out_dir, ckpt_name)) 181 | 182 | meta_graph_path = checkpoint.model_checkpoint_path + '.meta' 183 | self.iter_time = int(meta_graph_path.split('-')[-1].split('.')[0]) 184 | 185 | logger.info('[*] Load iter_time: {}'.format(self.iter_time)) 186 | return True 187 | else: 188 | return False 189 | -------------------------------------------------------------------------------- /src/tensorflow_utils.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------- 2 | # Tensorflow Utils Implementation 3 | # Licensed under The MIT License [see LICENSE for details] 4 | # Written by Cheng-Bin Jin 5 | # Email: sbkim0407@gmail.com 6 | # --------------------------------------------------------- 7 | import os 8 | import logging 9 | import functools 10 | import tensorflow as tf 11 | import tensorflow.contrib.slim as slim 12 | from tensorflow.python.training import moving_averages 13 | 14 | logger = logging.getLogger(__name__) # logger 15 | logger.setLevel(logging.INFO) 16 | 17 | 18 | def _init_logger(log_path): 19 | formatter = logging.Formatter('%(asctime)s:%(name)s:%(message)s') 20 | # file handler 21 | file_handler = logging.FileHandler(os.path.join(log_path, 'model.log')) 22 | file_handler.setFormatter(formatter) 23 | file_handler.setLevel(logging.INFO) 24 | # stream handler 25 | stream_handler = logging.StreamHandler() 26 | stream_handler.setFormatter(formatter) 27 | # add handlers 28 | logger.addHandler(file_handler) 29 | logger.addHandler(stream_handler) 30 | 31 | 32 | def padding2d(x, p_h=1, p_w=1, pad_type='REFLECT', 
name='pad2d'): 33 | if pad_type == 'REFLECT': 34 | return tf.pad(x, [[0, 0], [p_h, p_h], [p_w, p_w], [0, 0]], 'REFLECT', name=name) 35 | 36 | 37 | def conv2d(x, output_dim, k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02, padding='SAME', name='conv2d', is_print=True): 38 | with tf.variable_scope(name): 39 | w = tf.get_variable('w', [k_h, k_w, x.get_shape()[-1], output_dim], 40 | initializer=tf.truncated_normal_initializer(stddev=stddev)) 41 | conv = tf.nn.conv2d(x, w, strides=[1, d_h, d_w, 1], padding=padding) 42 | 43 | biases = tf.get_variable('biases', [output_dim], initializer=tf.constant_initializer(0.0)) 44 | # conv = tf.reshape(tf.nn.bias_add(conv, biases), conv.get_shape()) 45 | conv = tf.nn.bias_add(conv, biases) 46 | 47 | if is_print: 48 | print_activations(conv) 49 | 50 | return conv 51 | 52 | 53 | def deconv2d(x, k, k_h=3, k_w=3, d_h=2, d_w=2, stddev=0.02, padding_='SAME', output_size=None, 54 | name='deconv2d', with_w=False, is_print=True): 55 | with tf.variable_scope(name): 56 | input_shape = x.get_shape().as_list() 57 | 58 | # calculate output size 59 | h_output, w_output = None, None 60 | if not output_size: 61 | h_output, w_output = input_shape[1] * 2, input_shape[2] * 2 62 | # output_shape = [input_shape[0], h_output, w_output, k] # error when not define batch_size 63 | output_shape = [tf.shape(x)[0], h_output, w_output, k] 64 | 65 | # conv2d transpose 66 | w = tf.get_variable('w', [k_h, k_w, k, input_shape[3]], 67 | initializer=tf.random_normal_initializer(stddev=stddev)) 68 | deconv = tf.nn.conv2d_transpose(x, w, output_shape=output_shape, strides=[1, d_h, d_w, 1], 69 | padding=padding_) 70 | 71 | biases = tf.get_variable('biases', [output_shape[-1]], initializer=tf.constant_initializer(0.0)) 72 | deconv = tf.nn.bias_add(deconv, biases) 73 | 74 | if is_print: 75 | print_activations(deconv) 76 | 77 | if with_w: 78 | return deconv, w, biases 79 | else: 80 | return deconv 81 | 82 | 83 | def upsampling2d(x, size=(2, 2), name='upsampling2d'): 84 | with 
tf.name_scope(name): 85 | shape = x.get_shape().as_list() 86 | return tf.image.resize_nearest_neighbor(x, size=(size[0] * shape[1], size[1] * shape[2])) 87 | 88 | 89 | def linear(x, output_size, bias_start=0.0, with_w=False, name='fc'): 90 | shape = x.get_shape().as_list() 91 | 92 | with tf.variable_scope(name): 93 | matrix = tf.get_variable(name="matrix", shape=[shape[1], output_size], 94 | dtype=tf.float32, initializer=tf.contrib.layers.xavier_initializer()) 95 | bias = tf.get_variable(name="bias", shape=[output_size], 96 | initializer=tf.constant_initializer(bias_start)) 97 | if with_w: 98 | return tf.matmul(x, matrix) + bias, matrix, bias 99 | else: 100 | return tf.matmul(x, matrix) + bias 101 | 102 | 103 | def norm(x, name, _type, _ops, is_train=True): 104 | if _type == 'batch': 105 | return batch_norm(x, name=name, _ops=_ops, is_train=is_train) 106 | elif _type == 'instance': 107 | return instance_norm(x, name=name) 108 | elif _type == 'layer': 109 | return layer_norm(x, name=name) 110 | else: 111 | raise NotImplementedError 112 | 113 | 114 | def batch_norm(x, name, _ops, is_train=True): 115 | """Batch normalization.""" 116 | with tf.variable_scope(name): 117 | params_shape = [x.get_shape()[-1]] 118 | 119 | beta = tf.get_variable('beta', params_shape, tf.float32, 120 | initializer=tf.constant_initializer(0.0, tf.float32)) 121 | gamma = tf.get_variable('gamma', params_shape, tf.float32, 122 | initializer=tf.constant_initializer(1.0, tf.float32)) 123 | 124 | if is_train is True: 125 | mean, variance = tf.nn.moments(x, [0, 1, 2], name='moments') 126 | 127 | moving_mean = tf.get_variable('moving_mean', params_shape, tf.float32, 128 | initializer=tf.constant_initializer(0.0, tf.float32), 129 | trainable=False) 130 | moving_variance = tf.get_variable('moving_variance', params_shape, tf.float32, 131 | initializer=tf.constant_initializer(1.0, tf.float32), 132 | trainable=False) 133 | 134 | _ops.append(moving_averages.assign_moving_average(moving_mean, mean, 0.9)) 135 
| _ops.append(moving_averages.assign_moving_average(moving_variance, variance, 0.9)) 136 | else: 137 | mean = tf.get_variable('moving_mean', params_shape, tf.float32, 138 | initializer=tf.constant_initializer(0.0, tf.float32), trainable=False) 139 | variance = tf.get_variable('moving_variance', params_shape, tf.float32, trainable=False) 140 | 141 | # epsilon used to be 1e-5. Maybe 0.001 solves NaN problem in deeper net. 142 | y = tf.nn.batch_normalization(x, mean, variance, beta, gamma, 1e-5) 143 | y.set_shape(x.get_shape()) 144 | 145 | return y 146 | 147 | 148 | def instance_norm(x, name='instance_norm', mean=1.0, stddev=0.02, epsilon=1e-5): 149 | with tf.variable_scope(name): 150 | depth = x.get_shape()[3] 151 | scale = tf.get_variable( 152 | 'scale', [depth], tf.float32, 153 | initializer=tf.random_normal_initializer(mean=mean, stddev=stddev, dtype=tf.float32)) 154 | offset = tf.get_variable('offset', [depth], initializer=tf.constant_initializer(0.0)) 155 | 156 | # calcualte mean and variance as instance 157 | mean, variance = tf.nn.moments(x, axes=[1, 2], keep_dims=True) 158 | 159 | # normalization 160 | inv = tf.rsqrt(variance + epsilon) 161 | normalized = (x - mean) * inv 162 | 163 | return scale * normalized + offset 164 | 165 | 166 | # TODO: I'm not sure is it a good implementation of layer normalization... 167 | def layer_norm(x, name='layer_norm'): 168 | with tf.variable_scope(name): 169 | norm_axes = [1, 2, 3] 170 | mean, var = tf.nn.moments(x, axes=norm_axes, keep_dims=True) 171 | 172 | # Assume the 'neurons' axis is the third of norm_axes. This is the case for fully-connected 173 | # and BHWC conv layers. 174 | n_neurons = x.get_shape().as_list()[norm_axes[2]] 175 | offset = tf.get_variable('offset', n_neurons, tf.float32, initializer=tf.constant_initializer(0.0, tf.float32)) 176 | scale = tf.get_variable('scale', n_neurons, tf.float32, initializer=tf.constant_initializer(1.0, tf.float32)) 177 | 178 | # Add broadcasting dims to offset and scale (e.g. 
BCHW conv data) 179 | offset = tf.reshape(offset, [1 for _ in range(len(norm_axes)-1)] + [-1]) 180 | scale = tf.reshape(scale, [1 for _ in range(len(norm_axes)-1)] + [-1]) 181 | 182 | result = tf.nn.batch_normalization(x, mean, var, offset, scale, 1e-5) 183 | 184 | return result 185 | 186 | 187 | def n_res_blocks(x, _ops=None, norm_='instance', is_train=True, num_blocks=6, is_print=False): 188 | output = None 189 | for idx in range(1, num_blocks+1): 190 | output = res_block(x, x.get_shape()[3], _ops=_ops, norm_=norm_, is_train=is_train, 191 | name='res{}'.format(idx)) 192 | x = output 193 | 194 | if is_print: 195 | print_activations(output) 196 | 197 | return output 198 | 199 | 200 | # norm(x, name, _type, _ops, is_train=True) 201 | def res_block(x, k, _ops=None, norm_='instance', is_train=True, pad_type=None, name=None): 202 | with tf.variable_scope(name): 203 | conv1, conv2 = None, None 204 | 205 | # 3x3 Conv-Batch-Relu S1 206 | with tf.variable_scope('layer1'): 207 | if pad_type is None: 208 | conv1 = conv2d(x, k, k_h=3, k_w=3, d_h=1, d_w=1, padding='SAME', name='conv') 209 | elif pad_type == 'REFLECT': 210 | padded1 = padding2d(x, p_h=1, p_w=1, pad_type='REFLECT', name='padding') 211 | conv1 = conv2d(padded1, k, k_h=3, k_w=3, d_h=1, d_w=1, padding='VALID', name='conv') 212 | normalized1 = norm(conv1, name='norm', _type=norm_, _ops=_ops, is_train=is_train) 213 | relu1 = tf.nn.relu(normalized1) 214 | 215 | # 3x3 Conv-Batch S1 216 | with tf.variable_scope('layer2'): 217 | if pad_type is None: 218 | conv2 = conv2d(relu1, k, k_h=3, k_w=3, d_h=1, d_w=1, padding='SAME', name='conv') 219 | elif pad_type == 'REFLECT': 220 | padded2 = padding2d(relu1, p_h=1, p_w=1, pad_type='REFLECT', name='padding') 221 | conv2 = conv2d(padded2, k, k_h=3, k_w=3, d_h=1, d_w=1, padding='VALID', name='conv') 222 | normalized2 = norm(conv2, name='norm', _type=norm_, _ops=_ops, is_train=is_train) 223 | 224 | # sum layer1 and layer2 225 | output = x + normalized2 226 | return output 227 | 228 
def _report_if(tensor, is_print):
    """Pass ``tensor`` to ``print_activations`` when ``is_print`` is truthy.

    Returns:
        ``tensor`` itself, unchanged, so the call can be used inline.
    """
    if is_print:
        print_activations(tensor)
    return tensor


def identity(x, name='identity', is_print=False):
    """Wrap ``x`` in a ``tf.identity`` op.

    Args:
        x: input tensor.
        name: name given to the identity op.
        is_print: when true, the result is handed to ``print_activations``.

    Returns:
        The tensor produced by ``tf.identity``.
    """
    return _report_if(tf.identity(x, name=name), is_print)


def avgPoolConv(x, output_dim, filter_size=3, stride=1, name='avgPoolConv', is_print=True):
    """Apply ``avg_pool_2x2`` and then ``conv2d`` inside variable scope ``name``.

    Args:
        x: input tensor.
        output_dim: forwarded to ``conv2d`` as ``output_dim``.
        filter_size: used for both ``k_h`` and ``k_w`` of the convolution.
        stride: used for both ``d_h`` and ``d_w`` of the convolution.
        name: variable scope that encloses both ops.
        is_print: when true, the final tensor is handed to ``print_activations``.

    Returns:
        The output of the convolution.
    """
    with tf.variable_scope(name):
        pooled = avg_pool_2x2(x)
        convolved = conv2d(pooled, output_dim=output_dim, k_h=filter_size, k_w=filter_size,
                           d_h=stride, d_w=stride)
        return _report_if(convolved, is_print)


def convAvgPool(x, output_dim, filter_size=3, stride=1, name='convAvgPool', is_print=True):
    """Apply ``conv2d`` and then ``avg_pool_2x2`` inside variable scope ``name``.

    Same arguments as ``avgPoolConv``; only the order of the two ops is swapped.

    Returns:
        The pooled output of the convolution.
    """
    with tf.variable_scope(name):
        convolved = conv2d(x, output_dim=output_dim, k_h=filter_size, k_w=filter_size,
                           d_h=stride, d_w=stride)
        pooled = avg_pool_2x2(convolved)
        return _report_if(pooled, is_print)


def max_pool_2x2(x, name='max_pool'):
    """Max-pool ``x`` with ksize and strides of [1, 2, 2, 1] and 'SAME' padding.

    The op is created under ``tf.name_scope(name)``.
    """
    with tf.name_scope(name):
        pooled = tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
    return pooled


def avg_pool_2x2(x, name='avg_pool'):
    """Average-pool ``x`` with ksize and strides of [1, 2, 2, 1] and 'SAME' padding.

    The op is created under ``tf.name_scope(name)``.
    """
    with tf.name_scope(name):
        pooled = tf.nn.avg_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
    return pooled


def sigmoid(x, name='sigmoid', is_print=False):
    """Return ``tf.nn.sigmoid(x)`` named ``name``; optionally log it via ``print_activations``."""
    return _report_if(tf.nn.sigmoid(x, name=name), is_print)


def tanh(x, name='tanh', is_print=False):
    """Return ``tf.nn.tanh(x)`` named ``name``; optionally log it via ``print_activations``."""
    return _report_if(tf.nn.tanh(x, name=name), is_print)


def relu(x, name='relu', is_print=False):
    """Return ``tf.nn.relu(x)`` named ``name``; optionally log it via ``print_activations``."""
    return _report_if(tf.nn.relu(x, name=name), is_print)


def lrelu(x, leak=0.2, name='lrelu', is_print=False):
    """Leaky ReLU computed as the element-wise maximum of ``x`` and ``leak * x``.

    Args:
        x: input tensor.
        leak: multiplier applied to ``x`` for the second operand of the maximum.
        name: name given to the ``tf.maximum`` op.
        is_print: when true, the result is handed to ``print_activations``.

    Returns:
        The tensor produced by ``tf.maximum``.
    """
    scaled = leak*x
    return _report_if(tf.maximum(x, scaled, name=name), is_print)
297 | 298 | 299 | def xavier_init(in_dim): 300 | # print('in_dim: ', in_dim) 301 | xavier_stddev = 1. / tf.sqrt(in_dim / 2.) 302 | return xavier_stddev 303 | 304 | 305 | def print_activations(t): 306 | # print(t.op.name, ' ', t.get_shape().as_list()) 307 | logger.info(t.op.name + '{}'.format(t.get_shape().as_list())) 308 | 309 | 310 | def show_all_variables(): 311 | model_vars = tf.trainable_variables() 312 | slim.model_analyzer.analyze_vars(model_vars, print_info=True) 313 | 314 | 315 | def batch_convert2int(images): 316 | # images: 4D float tensor (batch_size, image_size, image_size, depth) 317 | return tf.map_fn(convert2int, images, dtype=tf.uint8) 318 | 319 | 320 | def convert2int(image): 321 | # transform from float tensor ([-1.,1.]) to int image ([0,255]) 322 | return tf.image.convert_image_dtype((image + 1.0) / 2.0, tf.uint8) 323 | 324 | 325 | def res_block_v2(x, k, filter_size, _ops=None, norm_='instance', is_train=True, resample=None, name=None): 326 | with tf.variable_scope(name): 327 | if resample == 'down': 328 | conv_shortcut = functools.partial(avgPoolConv, output_dim=k, filter_size=1) 329 | conv_1 = functools.partial(conv2d, output_dim=k, k_h=filter_size, k_w=filter_size, d_h=1, d_w=1) 330 | conv_2 = functools.partial(convAvgPool, output_dim=k) 331 | elif resample == 'up': 332 | conv_shortcut = functools.partial(deconv2d, k=k) 333 | conv_1 = functools.partial(deconv2d, k=k, k_h=filter_size, k_w=filter_size) 334 | conv_2 = functools.partial(conv2d, output_dim=k, k_h=filter_size, k_w=filter_size, d_h=1, d_w=1) 335 | elif resample is None: 336 | conv_shortcut = functools.partial(conv2d, output_dim=k, k_h=filter_size, k_w=filter_size, d_h=1, d_w=1) 337 | conv_1 = functools.partial(conv2d, output_dim=k, k_h=filter_size, k_w=filter_size, d_h=1, d_w=1) 338 | conv_2 = functools.partial(conv2d, output_dim=k, k_h=filter_size, k_w=filter_size, d_h=1, d_w=1) 339 | else: 340 | raise Exception('invalid resample value') 341 | 342 | if (k == 
x.get_shape().as_list()[3]) and (resample is None): 343 | shortcut = x # Identity skip-connection 344 | else: 345 | shortcut = conv_shortcut(x, name='shortcut') 346 | 347 | output = x 348 | output = norm(output, _type=norm_, _ops=_ops, is_train=is_train, name='norm1') 349 | output = relu(output, name='relu1') 350 | output = conv_1(output, name='conv1') 351 | output = norm(output, _type=norm_, _ops=_ops, is_train=is_train, name='norm2') 352 | output = relu(output, name='relu2') 353 | output = conv_2(output, name='conv2') 354 | 355 | return shortcut + output 356 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------- 2 | # Python Implementation 3 | # Licensed under The MIT License [see LICENSE for details] 4 | # Written by Cheng-Bin Jin 5 | # Email: sbkim0407@gmail.com 6 | # --------------------------------------------------------- 7 | import os 8 | import sys 9 | import random 10 | import numpy as np 11 | import matplotlib as mpl 12 | import scipy.misc 13 | mpl.use('TkAgg') # or whatever other backend that you want to solve Segmentation fault (core dumped) 14 | import matplotlib.pyplot as plt 15 | import matplotlib.gridspec as gridspec 16 | from PIL import Image 17 | 18 | 19 | class ImagePool(object): 20 | def __init__(self, pool_size=50): 21 | self.pool_size = pool_size 22 | self.imgs = [] 23 | 24 | def query(self, img): 25 | if self.pool_size == 0: 26 | return img 27 | 28 | if len(self.imgs) < self.pool_size: 29 | self.imgs.append(img) 30 | return img 31 | else: 32 | if random.random() > 0.5: 33 | # use old image 34 | random_id = random.randrange(0, self.pool_size) 35 | tmp_img = self.imgs[random_id].copy() 36 | self.imgs[random_id] = img.copy() 37 | return tmp_img 38 | else: 39 | return img 40 | 41 | 42 | def all_files_under(path, extension=None, append_path=True, sort=True): 43 
| if append_path: 44 | if extension is None: 45 | filenames = [os.path.join(path, fname) for fname in os.listdir(path)] 46 | else: 47 | filenames = [os.path.join(path, fname) 48 | for fname in os.listdir(path) if fname.endswith(extension)] 49 | else: 50 | if extension is None: 51 | filenames = [os.path.basename(fname) for fname in os.listdir(path)] 52 | else: 53 | filenames = [os.path.basename(fname) 54 | for fname in os.listdir(path) if fname.endswith(extension)] 55 | 56 | if sort: 57 | filenames = sorted(filenames) 58 | 59 | return filenames 60 | 61 | 62 | def imagefiles2arrs(filenames): 63 | img_shape = image_shape(filenames[0]) 64 | images_arr = None 65 | 66 | if len(img_shape) == 3: # color image 67 | images_arr = np.zeros((len(filenames), img_shape[0], img_shape[1], img_shape[2]), dtype=np.float32) 68 | elif len(img_shape) == 2: # gray scale image 69 | images_arr = np.zeros((len(filenames), img_shape[0], img_shape[1]), dtype=np.float32) 70 | 71 | for file_index in range(len(filenames)): 72 | img = Image.open(filenames[file_index]) 73 | images_arr[file_index] = np.asarray(img).astype(np.float32) 74 | 75 | return images_arr 76 | 77 | 78 | def image_shape(filename): 79 | img = Image.open(filename, mode="r") 80 | img_arr = np.asarray(img) 81 | img_shape = img_arr.shape 82 | return img_shape 83 | 84 | 85 | def print_metrics(itr, kargs): 86 | print("*** Iteration {} ====> ".format(itr)) 87 | for name, value in kargs.items(): 88 | print("{} : {}, ".format(name, value)) 89 | print("") 90 | sys.stdout.flush() 91 | 92 | 93 | def transform(img): 94 | return img / 127.5 - 1.0 95 | 96 | 97 | def inverse_transform(img): 98 | return (img + 1.) / 2. 
99 | 100 | 101 | def preprocess_pair(img_a, img_b, load_size=286, fine_size=256, flip=True, is_test=False): 102 | if is_test: 103 | img_a = scipy.misc.imresize(img_a, [fine_size, fine_size]) 104 | img_b = scipy.misc.imresize(img_b, [fine_size, fine_size]) 105 | else: 106 | img_a = scipy.misc.imresize(img_a, [load_size, load_size]) 107 | img_b = scipy.misc.imresize(img_b, [load_size, load_size]) 108 | 109 | h1 = int(np.ceil(np.random.uniform(1e-2, load_size - fine_size))) 110 | w1 = int(np.ceil(np.random.uniform(1e-2, load_size - fine_size))) 111 | img_a = img_a[h1:h1 + fine_size, w1:w1 + fine_size] 112 | img_b = img_b[h1:h1 + fine_size, w1:w1 + fine_size] 113 | 114 | if flip and np.random.random() > 0.5: 115 | img_a = np.fliplr(img_a) 116 | img_b = np.fliplr(img_b) 117 | 118 | return img_a, img_b 119 | 120 | 121 | def imread(path, is_gray_scale=False, img_size=None): 122 | if is_gray_scale: 123 | img = scipy.misc.imread(path, flatten=True).astype(np.float) 124 | else: 125 | img = scipy.misc.imread(path, mode='RGB').astype(np.float) 126 | 127 | if not (img.ndim == 3 and img.shape[2] == 3): 128 | img = np.dstack((img, img, img)) 129 | 130 | if img_size is not None: 131 | img = scipy.misc.imresize(img, img_size) 132 | 133 | return img 134 | 135 | 136 | def load_image(image_path, which_direction=0, is_gray_scale=True, img_size=(256, 256, 1)): 137 | input_img = imread(image_path, is_gray_scale=is_gray_scale, img_size=img_size) 138 | w_pair = int(input_img.shape[1]) 139 | w_single = int(w_pair / 2) 140 | 141 | if which_direction == 0: # A to B 142 | img_a = input_img[:, 0:w_single] 143 | img_b = input_img[:, w_single:w_pair] 144 | else: # B to A 145 | img_a = input_img[:, w_single:w_pair] 146 | img_b = input_img[:, 0:w_single] 147 | 148 | return img_a, img_b 149 | 150 | 151 | def load_data(image_path, is_gray_scale=False): 152 | img = imread(path=image_path, is_gray_scale=is_gray_scale) 153 | img_trans = transform(img) # from [0, 255] to [-1., 1.] 
154 | 155 | if is_gray_scale and (img_trans.ndim == 2): 156 | img_trans = np.expand_dims(img_trans, axis=2) 157 | 158 | return img_trans 159 | 160 | 161 | def plots(imgs, iter_time, save_file, grid_cols, grid_rows, sample_batch, name=None): 162 | # parameters for plot size 163 | scale, margin = 0.02, 0.02 164 | 165 | # save more bigger image 166 | img_h, img_w, img_c = imgs.shape[1:] 167 | fig = plt.figure(figsize=(img_w * grid_cols * scale, img_h * grid_rows * scale)) # (column, row) 168 | gs = gridspec.GridSpec(grid_rows, grid_cols) # (row, column) 169 | gs.update(wspace=margin, hspace=margin) 170 | 171 | for img_idx in range(sample_batch): 172 | ax = plt.subplot(gs[img_idx]) 173 | plt.axis('off') 174 | ax.set_xticklabels([]) 175 | ax.set_yticklabels([]) 176 | ax.set_aspect('equal') 177 | 178 | if imgs[img_idx].shape[2] == 1: # gray scale 179 | plt.imshow((imgs[img_idx]).reshape(img_h, img_w), cmap='Greys_r') 180 | else: 181 | plt.imshow((imgs[img_idx]).reshape(img_h, img_w, img_c), cmap='Greys_r') 182 | 183 | plt.savefig(save_file + '/{}_{}.png'.format(str(iter_time), name), bbox_inches='tight') 184 | plt.close(fig) 185 | -------------------------------------------------------------------------------- /src/wgan_gp.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------- 2 | # TensorFlow WGAN-GP Implementation 3 | # Licensed under The MIT License [see LICENSE for details] 4 | # Written by Cheng-Bin Jin 5 | # Email: sbkim0407@gmail.com 6 | # --------------------------------------------------------- 7 | import logging 8 | import collections 9 | import numpy as np 10 | import tensorflow as tf 11 | from tensorflow.contrib.layers import flatten 12 | import matplotlib as mpl 13 | mpl.use('TkAgg') # or whatever other backend that you want to solve Segmentation fault (core dumped) 14 | import matplotlib.pyplot as plt 15 | import matplotlib.gridspec as gridspec 16 | 17 | # 
noinspection PyPep8Naming 18 | import tensorflow_utils as tf_utils 19 | import utils as utils 20 | 21 | logger = logging.getLogger(__name__) # logger 22 | logger.setLevel(logging.INFO) 23 | 24 | 25 | # noinspection PyPep8Naming 26 | class WGAN_GP(object): 27 | def __init__(self, sess, flags, dataset, log_path=None): 28 | self.sess = sess 29 | self.flags = flags 30 | self.dataset = dataset 31 | self.image_size = dataset.image_size 32 | self.log_path = log_path 33 | 34 | if self.flags.dataset == 'mnist': 35 | self.gen_c = [4*4*256, 128, 64] 36 | self.dis_c = [64, 128, 256] 37 | elif self.flags.dataset == 'cifar10': 38 | self.gen_c = [4*4*4*128, 256, 128] 39 | self.dis_c = [128, 256, 512] 40 | elif self.flags.dataset == 'imagenet64': 41 | self.gen_c = [4*4*8*64, 512, 256, 128, 64] 42 | self.dis_c = [64, 128, 256, 512, 512] 43 | else: 44 | raise NotImplementedError 45 | 46 | self.gen_train_ops, self.dis_train_ops = [], [] 47 | 48 | self._init_logger() # init logger 49 | self._build_net() # init graph 50 | self._tensorboard() # init tensorboard 51 | logger.info("Initialized WGAN-GP SUCCESS!") 52 | 53 | def _init_logger(self): 54 | if self.flags.is_train: 55 | tf_utils._init_logger(self.log_path) 56 | 57 | def _build_net(self): 58 | self.Y = tf.placeholder(tf.float32, shape=[None, *self.image_size], name='real_data') 59 | self.z = tf.placeholder(tf.float32, shape=[None, self.flags.z_dim], name='latent_vector') 60 | 61 | if self.flags.dataset == 'imagenet64': 62 | self.generator = self.resnetGenerator 63 | self.discriminator = self.resnetDiscriminator 64 | else: 65 | self.generator = self.basicGenerator 66 | self.discriminator = self.basicDiscriminator 67 | 68 | self.g_samples = self.generator(self.z) 69 | _, d_logit_real = self.discriminator(self.Y) 70 | _, d_logit_fake = self.discriminator(self.g_samples, is_reuse=True) 71 | 72 | # discriminator loss 73 | self.wgan_d_loss = tf.reduce_mean(d_logit_fake) - tf.reduce_mean(d_logit_real) 74 | # generator loss 75 | 
self.g_loss = -tf.reduce_mean(d_logit_fake) 76 | 77 | d_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='d_') 78 | g_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='g_') 79 | 80 | # gradient penalty 81 | self.gp_loss = self.gradient_penalty() 82 | self.d_loss = self.wgan_d_loss + self.flags.lambda_ * self.gp_loss 83 | 84 | # Optimizers for generator and discriminator 85 | self.gen_optim = tf.train.AdamOptimizer( 86 | learning_rate=self.flags.learning_rate, beta1=0.5, beta2=0.9).minimize(self.g_loss, var_list=g_vars) 87 | self.dis_optim = tf.train.AdamOptimizer( 88 | learning_rate=self.flags.learning_rate, beta1=0.5, beta2=0.9).minimize(self.d_loss, var_list=d_vars) 89 | 90 | def gradient_penalty(self): 91 | alpha = tf.random_uniform(shape=[self.flags.batch_size, 1, 1, 1], minval=0., maxval=1.) 92 | differences = self.g_samples - self.Y 93 | interpolates = self.Y + (alpha * differences) 94 | gradients = tf.gradients(self.discriminator(interpolates, is_reuse=True), [interpolates])[0] 95 | slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), reduction_indices=[1, 2, 3])) 96 | gradient_penalty = tf.reduce_mean((slopes-1.)**2) 97 | 98 | return gradient_penalty 99 | 100 | def _tensorboard(self): 101 | tf.summary.scalar('loss/negative_wgan_d_loss', -self.wgan_d_loss) 102 | tf.summary.scalar('loss/gp_loss', self.gp_loss) 103 | tf.summary.scalar('loss/negative_d_loss', -self.d_loss) # negative critic loss 104 | tf.summary.scalar('loss/g_loss', self.g_loss) 105 | 106 | self.summary_op = tf.summary.merge_all() 107 | 108 | def basicGenerator(self, data, name='g_'): 109 | with tf.variable_scope(name): 110 | data_flatten = flatten(data) 111 | tf_utils.print_activations(data_flatten) 112 | 113 | # from (N, 128) to (N, 4, 4, 256) 114 | h0_linear = tf_utils.linear(data_flatten, self.gen_c[0], name='h0_linear') 115 | if self.flags.dataset == 'cifar10': 116 | h0_linear = tf.reshape(h0_linear, [tf.shape(h0_linear)[0], 4, 4, int(self.gen_c[0] / (4 * 
4))]) 117 | h0_linear = tf_utils.norm(h0_linear, _type='batch', _ops=self.gen_train_ops, name='h0_norm') 118 | h0_relu = tf.nn.relu(h0_linear, name='h0_relu') 119 | h0_reshape = tf.reshape(h0_relu, [tf.shape(h0_relu)[0], 4, 4, int(self.gen_c[0]/(4*4))]) 120 | 121 | # from (N, 4, 4, 256) to (N, 8, 8, 128) 122 | h1_deconv = tf_utils.deconv2d(h0_reshape, self.gen_c[1], k_h=5, k_w=5, name='h1_deconv2d') 123 | if self.flags.dataset == 'cifar10': 124 | h1_deconv = tf_utils.norm(h1_deconv, _type='batch', _ops=self.gen_train_ops, name='h1_norm') 125 | h1_relu = tf.nn.relu(h1_deconv, name='h1_relu') 126 | 127 | # from (N, 8, 8, 128) to (N, 16, 16, 64) 128 | h2_deconv = tf_utils.deconv2d(h1_relu, self.gen_c[2], k_h=5, k_w=5, name='h2_deconv2d') 129 | if self.flags.dataset == 'cifar10': 130 | h2_deconv = tf_utils.norm(h2_deconv, _type='batch', _ops=self.gen_train_ops, name='h2_norm') 131 | h2_relu = tf.nn.relu(h2_deconv, name='h2_relu') 132 | 133 | # from (N, 16, 16, 64) to (N, 32, 32, 1) 134 | output = tf_utils.deconv2d(h2_relu, self.image_size[2], k_h=5, k_w=5, name='h3_deconv2d') 135 | 136 | return tf_utils.tanh(output) 137 | 138 | def basicDiscriminator(self, data, name='d_', is_reuse=False): 139 | with tf.variable_scope(name) as scope: 140 | if is_reuse is True: 141 | scope.reuse_variables() 142 | tf_utils.print_activations(data) 143 | 144 | # from (N, 32, 32, 1) to (N, 16, 16, 64) 145 | h0_conv = tf_utils.conv2d(data, self.dis_c[0], k_h=5, k_w=5, name='h0_conv2d') 146 | h0_lrelu = tf_utils.lrelu(h0_conv, name='h0_lrelu') 147 | 148 | # from (N, 16, 16, 64) to (N, 8, 8, 128) 149 | h1_conv = tf_utils.conv2d(h0_lrelu, self.dis_c[1], k_h=5, k_w=5, name='h1_conv2d') 150 | h1_lrelu = tf_utils.lrelu(h1_conv, name='h1_lrelu') 151 | 152 | # from (N, 8, 8, 128) to (N, 4, 4, 256) 153 | h2_conv = tf_utils.conv2d(h1_lrelu, self.dis_c[2], k_h=5, k_w=5, name='h2_conv2d') 154 | h2_lrelu = tf_utils.lrelu(h2_conv, name='h2_lrelu') 155 | 156 | # from (N, 4, 4, 256) to (N, 4096) and to (N, 
1) 157 | h2_flatten = flatten(h2_lrelu) 158 | h3_linear = tf_utils.linear(h2_flatten, 1, name='h3_linear') 159 | 160 | return tf.nn.sigmoid(h3_linear), h3_linear 161 | 162 | def resnetGenerator(self, data, name='g_'): 163 | with tf.variable_scope(name): 164 | data_flatten = flatten(data) 165 | tf_utils.print_activations(data_flatten) 166 | 167 | # from (N, 128) to (N, 4, 4, 512) 168 | h0_linear = tf_utils.linear(data_flatten, self.gen_c[0], name='h0_linear') 169 | h0_reshape = tf.reshape(h0_linear, [tf.shape(h0_linear)[0], 4, 4, int(self.gen_c[0]/(4*4))]) 170 | 171 | # (N, 8, 8, 512) 172 | resblock_1 = tf_utils.res_block_v2(h0_reshape, self.gen_c[1], filter_size=3, _ops=self.gen_train_ops, 173 | norm_='batch', resample='up', name='res_block_1') 174 | # (N, 16, 16, 256) 175 | resblock_2 = tf_utils.res_block_v2(resblock_1, self.gen_c[2], filter_size=3, _ops=self.gen_train_ops, 176 | norm_='batch', resample='up', name='res_block_2') 177 | # (N, 32, 32, 128) 178 | resblock_3 = tf_utils.res_block_v2(resblock_2, self.gen_c[3], filter_size=3, _ops=self.gen_train_ops, 179 | norm_='batch', resample='up', name='res_block_3') 180 | # (N, 64, 64, 64) 181 | resblock_4 = tf_utils.res_block_v2(resblock_3, self.gen_c[4], filter_size=3, _ops=self.gen_train_ops, 182 | norm_='batch', resample='up', name='res_block_4') 183 | 184 | norm_5 = tf_utils.norm(resblock_4, _type='batch', _ops=self.gen_train_ops, name='norm_5') 185 | relu_5 = tf_utils.relu(norm_5, name='relu_5') 186 | # (N, 64, 64, 3) 187 | output = tf_utils.conv2d(relu_5, output_dim=self.image_size[2], k_w=3, k_h=3, d_h=1, d_w=1, name='output') 188 | 189 | return tf_utils.tanh(output) 190 | 191 | def resnetDiscriminator(self, data, name='d_', is_reuse=False): 192 | with tf.variable_scope(name) as scope: 193 | if is_reuse is True: 194 | scope.reuse_variables() 195 | tf_utils.print_activations(data) 196 | 197 | # (N, 64, 64, 64) 198 | conv_0 = tf_utils.conv2d(data, output_dim=self.dis_c[0], k_h=3, k_w=3, d_h=1, d_w=1, 
name='conv_0') 199 | # (N, 32, 32, 128) 200 | resblock_1 = tf_utils.res_block_v2(conv_0, self.dis_c[1], filter_size=3, _ops=self.dis_train_ops, 201 | norm_='layer', resample='down', name='res_block_1') 202 | # (N, 16, 16, 256) 203 | resblock_2 = tf_utils.res_block_v2(resblock_1, self.dis_c[2], filter_size=3, _ops=self.dis_train_ops, 204 | norm_='layer', resample='down', name='res_block_2') 205 | # (N, 8, 8, 512) 206 | resblock_3 = tf_utils.res_block_v2(resblock_2, self.dis_c[3], filter_size=3, _ops=self.dis_train_ops, 207 | norm_='layer', resample='down', name='res_block_3') 208 | # (N, 4, 4, 512) 209 | resblock_4 = tf_utils.res_block_v2(resblock_3, self.dis_c[4], filter_size=3, _ops=self.dis_train_ops, 210 | norm_='layer', resample='down', name='res_block_4') 211 | # (N, 4*4*512) 212 | flatten_5 = flatten(resblock_4) 213 | output = tf_utils.linear(flatten_5, 1, name='output') 214 | 215 | return tf.nn.sigmoid(output), output 216 | 217 | def train_step(self): 218 | wgan_d_loss, gp_loss, d_loss = None, None, None 219 | 220 | # train discriminator 221 | for idx in range(self.flags.num_critic): 222 | batch_imgs = self.dataset.train_next_batch(batch_size=self.flags.batch_size) 223 | dis_feed = {self.z: self.sample_z(num=self.flags.batch_size), self.Y: batch_imgs} 224 | dis_run = [self.dis_optim, self.wgan_d_loss, self.gp_loss, self.d_loss] 225 | _, wgan_d_loss, gp_loss, d_loss = self.sess.run(dis_run, feed_dict=dis_feed) 226 | 227 | # train generator 228 | batch_imgs = self.dataset.train_next_batch(batch_size=self.flags.batch_size) 229 | gen_feed = {self.z: self.sample_z(num=self.flags.batch_size), self.Y: batch_imgs} 230 | _, g_loss, summary = self.sess.run([self.gen_optim, self.g_loss, self.summary_op], feed_dict=gen_feed) 231 | 232 | # negative critic loss 233 | return [-wgan_d_loss, gp_loss, -d_loss, g_loss], summary 234 | 235 | def test_step(self): 236 | return self.sample_imgs() 237 | 238 | def sample_imgs(self, sample_size=64): 239 | g_feed = {self.z: 
self.sample_z(num=sample_size)} 240 | y_fakes = self.sess.run(self.g_samples, feed_dict=g_feed) 241 | 242 | return [y_fakes] 243 | 244 | def sample_z(self, num=64): 245 | return np.random.uniform(-1., 1., size=[num, self.flags.z_dim]) 246 | 247 | def print_info(self, loss, iter_time): 248 | if np.mod(iter_time, self.flags.print_freq) == 0: 249 | ord_output = collections.OrderedDict([('cur_iter', iter_time), ('tar_iters', self.flags.iters), 250 | ('batch_size', self.flags.batch_size), 251 | ('wgan_d_loss', loss[0]), ('gp_loss', loss[1]), 252 | ('d_loss', loss[2]), ('g_loss', loss[3]), 253 | ('dataset', self.flags.dataset), 254 | ('gpu_index', self.flags.gpu_index)]) 255 | 256 | utils.print_metrics(iter_time, ord_output) 257 | 258 | def plots(self, imgs_, iter_time, save_file): 259 | # reshape image from vector to (N, H, W, C) 260 | imgs_fake = np.reshape(imgs_[0], (self.flags.sample_batch, *self.image_size)) 261 | 262 | imgs = [] 263 | for img in imgs_fake: 264 | imgs.append(img) 265 | 266 | # parameters for plot size 267 | scale, margin = 0.04, 0.01 268 | n_cols, n_rows = int(np.sqrt(len(imgs))), int(np.sqrt(len(imgs))) 269 | cell_size_h, cell_size_w = imgs[0].shape[0] * scale, imgs[0].shape[1] * scale 270 | 271 | fig = plt.figure(figsize=(cell_size_w * n_cols, cell_size_h * n_rows)) # (column, row) 272 | gs = gridspec.GridSpec(n_rows, n_cols) # (row, column) 273 | gs.update(wspace=margin, hspace=margin) 274 | 275 | imgs = [utils.inverse_transform(imgs[idx]) for idx in range(len(imgs))] 276 | 277 | # save more bigger image 278 | for col_index in range(n_cols): 279 | for row_index in range(n_rows): 280 | ax = plt.subplot(gs[row_index * n_cols + col_index]) 281 | plt.axis('off') 282 | ax.set_xticklabels([]) 283 | ax.set_yticklabels([]) 284 | ax.set_aspect('equal') 285 | if self.image_size[2] == 3: 286 | plt.imshow((imgs[row_index * n_cols + col_index]).reshape( 287 | self.image_size[0], self.image_size[1], self.image_size[2]), cmap='Greys_r') 288 | elif 
self.image_size[2] == 1: 289 | plt.imshow((imgs[row_index * n_cols + col_index]).reshape( 290 | self.image_size[0], self.image_size[1]), cmap='Greys_r') 291 | else: 292 | raise NotImplementedError 293 | 294 | plt.savefig(save_file + '/sample_{}.png'.format(str(iter_time)), bbox_inches='tight') 295 | plt.close(fig) 296 | -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | [tox] 2 | skipsdist=True 3 | envlist = py35, py36, py37 4 | 5 | [testenv] 6 | deps = -rrequirements.d/base.txt 7 | --------------------------------------------------------------------------------